diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bc51a129..b97d4ccb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,30 @@ # CUTLASS 4.x +## [4.5.0](https://github.com/NVIDIA/cutlass/tree/main) (2026-03-27) + +### CuTe DSL +* New features + - Auto-deduced smem size for launching kernels + - Launch config `smem` now defaults to `None` for auto-calculating kernel shared memory usage, which is recommended unless manual control is required. + - Warnings will be raised when the manually set shared memory size is insufficient or exceeds the GPU maximum. + - The default shared memory usage calculation aligns with CUDA C++ static shared memory behavior, i.e. summing all allocations additively. + - An additional launch option `smem_merge_branch_allocs` is provided to merge shared memory allocations across mutually exclusive code branches, which is recommended for inlined mega-kernels to reduce total footprint. + +* Bug fixing and improvements + - Improved source code correlation for profiling/debugging + +### CUTLASS C++ +* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green context SM partition + - Enables launching GEMM on stream with partial SM allocation. +* Fix some kernel issues: + - Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates + - Fix CUTLASS clang build issues +* Fix some profiler issues: + - Add missing reference kernels for blockwise GEMM profiler +* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! +* Optimal code generation with CUDA toolkit versions 13.2. + ## [4.4.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.4.2) (2026-03-13) ### CuTe DSL @@ -30,7 +54,7 @@ + Set up with cutlass/python/CuTeDSL/setup.sh --cu13 + Refer to https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html for more details - GB300 is now supported in CuTe DSL with CTK 13.1 - + Refer to [SM103 batched 3xFP4 blockscaled GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py) for example kernel + + Refer to [SM103 batched FP4 Ultra blockscaled GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py) for example kernel - cute.experimental: introduce a higher-level, composable layer on top of existing CuTe DSL APIs (not a separate abstraction), which can be mixed with existing Cute DSL building blocks. + Fragment-free programming model: copy/dot APIs take memrefs directly instead of descriptors/fragments. + Automatic TMA descriptor generation and update insertion. @@ -53,7 +77,7 @@ - It is possible now to have customized epilogue fusion for persistent dense GEMM through a Python Epilogue Fusion Configuration (EFC) function, somewhat similar to CUTLASS C++ EVT. It also provides a PyTorch evaluator to compare the results. * More examples of authorizing peak-performance kernels - - [SM103 batched 3xFP4 blockscaled GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py) + - [SM103 batched FP4 Ultra blockscaled GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py) - Mixed input FMHA decode example with support for int4 KV (int8 KV supported in 4.3) - New acc_scale grouped mixed input gemm kernel variant is introduced to deliver better performance for decoding cases. - All mixed_input_gemm examples are moved into a separate folder `mixed_input_gemm`. Common utility functions are also extracted into mixed_input_host_utils.py under the same folder. diff --git a/CMakeLists.txt b/CMakeLists.txt index 0c9c35d34..2b5a29365 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,8 +96,6 @@ endif() if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") endif() -find_package(Doxygen QUIET) - ################################################################################ # @@ -789,31 +787,6 @@ install( ################################################################################ -# Doxygen is available. Generate documentation -if (DOXYGEN_FOUND) - # DOT is available. Enable graph generation in the documentation - if (DOXYGEN_DOT_EXECUTABLE) - set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.") - else() - set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE) - endif() - - if (CUTLASS_ENABLE_DOXYGEN_DOT) - set(HAVE_DOT "YES") - else() - set(HAVE_DOT "NO") - endif() - - # Add custom target for Doxygen. - add_custom_target(cutlass_docs ${CMAKE_COMMAND} -E env - "DOT_PATH=${DOXYGEN_DOT_EXECUTABLE}" - "HAVE_DOT=${HAVE_DOT}" - ${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/Doxyfile - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - VERBATIM - ) -endif() - if(NOT WIN32) # Add common library search paths so executables and libraries can load and run # without LD_LIBRARY_PATH being set. diff --git a/README.md b/README.md index 122369cb5..af62b1727 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") # Overview -# CUTLASS 4.4.2 +# CUTLASS 4.5.0 -_CUTLASS 4.4.2 - March 2026_ +_CUTLASS 4.5.0 - March 2026_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -43,117 +43,29 @@ To get started quickly - please refer : - [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html). - [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html). -# What's New in CUTLASS 4.4 +# What's New in CUTLASS 4.5 -## CuTe DSL +### CuTe DSL * New features - - CuTe DSL now supports CUDA toolkit 13.1! - + Set up with cutlass/python/CuTeDSL/setup.sh --cu13 - + Refer to https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html for more details - - GB300 is now supported in CuTe DSL with CTK 13.1 - + Refer to [SM103 batched 3xFP4 blockscaled GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py) for example kernel - - cute.experimental: introduce a higher-level, composable layer on top of existing CuTe DSL APIs (not a separate abstraction), which can be mixed with existing Cute DSL building blocks. - + Fragment-free programming model: copy/dot APIs take memrefs directly instead of descriptors/fragments. - + Automatic TMA descriptor generation and update insertion. - + Automatic vectorization and predication for SIMT copies. - + New pipeline abstraction with convenience wrappers - + New Partition ops to simplify partitioning logic. - + Device-side TMA descriptor allocation, initialization, and management - + These examples can be found here https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/experimental - - Ahead of Time (AoT) compilation is now available! - + Refer to files under https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/cute/export for example usage - - JAX support - you can now use CuTeDSL along with JAX - + Refer to files under https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/jax for example usage - - Introduced versioning support in DSL: - + cutlass.__version__ for a string representation of DSL version - + cutlass.CUDA_VERSION for a version class to tell the CUDA version used for DSL - - Added CopyDsmemStoreOp to store data to distributed shared memory with explicit synchronization. - - Grouped GEMM example now supports device-only problem shapes. - - We allow grid carve-out without problem shapes being available on host. - - Tma+LdMatrix features for loading+unpacking narrow-width types (refer to mixed_input_fmha_decode.py for example usage). - - It is possible now to have customized epilogue fusion for persistent dense GEMM through a Python Epilogue Fusion Configuration (EFC) function, somewhat similar to CUTLASS C++ EVT. It also provides a PyTorch evaluator to compare the results. - - CuTe DSL now supports Python 3.14 for both x86_64 and aarch64 - - Runtime Pointer/Tensor/FakeTensor now supports __cache_key__, providing a stable, hashable representation that simplifies and improves compiled function caching. - -* More examples of authorizing peak-performance kernels - - [SM103 batched 3xFP4 blockscaled GEMM kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py) - - Mixed input FMHA decode example with support for int4 KV (int8 KV supported in 4.3) - - New acc_scale grouped mixed input gemm kernel variant is introduced to deliver better performance for decoding cases. - - All mixed_input_gemm examples are moved into a separate folder `mixed_input_gemm`. Common utility functions are also extracted into mixed_input_host_utils.py under the same folder. + - Auto-deduced smem size for launching kernels + - Launch config `smem` now defaults to `None` for auto-calculating kernel shared memory usage, which is recommended unless manual control is required. + - Warnings will be raised when the manually set shared memory size is insufficient or exceeds the GPU maximum. + - The default shared memory usage calculation aligns with CUDA C++ static shared memory behavior, i.e. summing all allocations additively. + - An additional launch option `smem_merge_branch_allocs` is provided to merge shared memory allocations across mutually exclusive code branches, which is recommended for inlined mega-kernels to reduce total footprint. * Bug fixing and improvements - - Fixed an issue that both branches of if are executed - - Fixed `cute.printf` with f-string - - Fixed an indexing issue of scalar tensor - - Fixed small K reference check error for cta_tile_n = 256 case with overlapping accumulator optimization in [Blackwell SM100 persistent dense blockscaled GEMM with static scheduling](https://github.com/NVIDIA/cutlass/tree/main/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py). - - Fixed a segfault issue with tvm-ffi on aarch64 - - Fixed Hopper FMHA causal attention performance regression on CUDA toolkit 13.1 by - optimizing mbarrier synchronization to avoid unnecessary convergence barriers. - - Fix kernel loading race condition when multiple GPU are present in the same process in JAX. + - Improved source code correlation for profiling/debugging -* API changes - - Deprecate get_num_tmem_alloc_cols from blackwell_helpers.py. Use the one from tmem_allocator.py instead. - - Deprecate SM100_TMEM_CAPACITY_COLUMNS and SM100_TMEM_MIN_ALLOC_COLUMNS. - - LdMatrix16x16x8bOp and StMatrix16x8x8bOp now require explicit transpose=True when calling __init__, to avoid ambiguity in data transposition. - - LdMatrix16x16x8bOp copy traits updated to be faithful to PTX without permutations. Permuted variant is renamed to LdMatrix16x8x8bOp. - - Grouped GEMM example takes the argument --host_problem_shape_available. If the argument is provided, grid is carved out based upon the host problem shapes, otherwise, we launch maximum possible SMs. - - hardware_info.get_max_active_cluster support pass in specific stream to query. Useful for green context based SM partition. - - group_bulk_copy_modes in async bulk copy example is now deprecated, use group_modes directly instead. - - Deprecate nvvm wrapper from using nvvm enum, use str instead. - - cute.arch.calc_packed_f32x2_op default enable ftz to default disable ftz - - In CuTe DSL with CTK 13.1, following APIs in cutlass.cute.arch now require string literal instead of enum as argument: - + fence_proxy - + fence_view_async_tmem_op - + calc_packed_f32x2_op - + warp_redux_sync - + atomic_add - + atomic_and - + atomic_or - + atomic_xor - + atomic_max - + atomic_min - + atomic_exch - + atomic_cas - + store - + load - -* Use 'Advanced control file' for mixed input gemm examples for better performance. - - Advanced control file is an experimental feature of CUDA compiler. The controls file contains internal compiler settings tuned for specific kernels with a specific version of CUDA toolkit to get better GPU kernel code. More details and documentation on how to create these controls files will be provided in future CUDA toolkit release. Note: The advanced compiler control file is not expected to work for kernels that it was not tuned for. There is no compatibility guarantee, and the controls file will not work for CUDA toolkit with a different version. - -## CUTLASS C++ -* Add [example 93](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa/) for Blackwell low latency generation phase GQA kernel. - - Flash Decoding with cluster reduction. - - Kernel design details please check [Readme](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa/readme.md). -* Add Blackwell SM100 State Space Decomposition (SSD) kernel in [example 112](https://github.com/NVIDIA/cutlass/tree/main/examples/112_blackwell_ssd). -* Add Hopper SM90 State Space Decomposition (SSD) kernel in [example 111](https://github.com/NVIDIA/cutlass/tree/main/examples/111_hopper_ssd). -* Add Hopper e2m1 to fp32 optimized conversion and e2m1 * TF32 tensor core GEMM. - - Enable [example 55](https://github.com/NVIDIA/cutlass/tree/main/examples/55_hopper_mixed_dtype_gemm) with TF32 support -* Add [example 94](https://github.com/NVIDIA/cutlass/tree/main/examples/94_ada_fp8_blockwise/) for Ada FP8xFP8 -> BF16 GEMM with blockwise dequantization of input matrices in the MMA loop with FP32 accumulation. -* Add support for arbitrary application-provided strides for block-scale tensors. - - Users and applications now must pass valid block-scale strides in all cases, even when the tensor is packed. -* Support 4x blockscaled public ptx for CUDA 13.1. -* Enable Blackwell SM120f compilation of examples and exposes NVFP4/MX Grouped GEMM in the CUTLASS Profiler. -* Allow non-static `TmaGbasis` in `AuxTmaParams`. - - Some cases in attention kernel may require non-static `tma_gbasis`. - - Relax the restriction on `TmaGbasis` parameter of `AuxTmaParams` and users are allowed to manually construct a dynamic gbasis. +### CUTLASS C++ +* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green context SM partition + - Enables launching GEMM on stream with partial SM allocation. * Fix some kernel issues: - - Fix MSVC pre process issue. - - Fix a self assign issue in GEMV kernel. - - Fix a TMA descriptor bug where the CUDA driver is not properly setting the OOB address gen mode correctly. - - Fix memory fence for clc scheduler in Blackwell SM120 pingpong kernel. - - Fix missing SMEM alignment in Blackwell SM120 scale factors. - - Fix a PDL issue for grouped gemm. - - Fix divide-by-zero issue in canimplement for sm100 implicit gemm kernels. - - Fix cluster swizzle for Grouped GEMMs. - + Move host-side swizzling heuristics to device. - + Apply swizzle per group based on problem shape and max swizzle size. - + Improve examples and unit tests. + - Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates + - Fix CUTLASS clang build issues * Fix some profiler issues: - - Fix a core dump issue for nvfp4 grouped GEMM kernel. - - Fix inconsistent GEMM verification logic. - - Rework grouped gemm verification logic for different types. - - Fix api break change in using nvMatmulHeuristics. -* Fix some failed links under `media/docs`. + - Add missing reference kernels for blockwise GEMM profiler +* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! +* Optimal code generation with CUDA toolkit versions 13.2. Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. diff --git a/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt b/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt index 5c97a9ba1..ffd20f810 100644 --- a/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt +++ b/examples/80_blackwell_geforce_sparse_gemm/CMakeLists.txt @@ -27,7 +27,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if (CUTLASS_NVCC_ARCHS MATCHES "120a|120f|121a") +if (CUTLASS_NVCC_ARCHS MATCHES "120a|121a") cutlass_example_add_executable( 80a_blackwell_geforce_mxfp8_bf16_sparse_gemm 80a_blackwell_geforce_mxfp8_bf16_sparse_gemm.cu diff --git a/examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context.cu b/examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context.cu new file mode 100644 index 000000000..076f94926 --- /dev/null +++ b/examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context.cu @@ -0,0 +1,712 @@ +/*************************************************************************************************** + * Copyright (c) 2026 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* + * Blackwell FP16 GEMM with Green Context + * + * The two schedulers share identical kernel configurations (element types, tile shapes, + * cluster shape, epilogue/mainloop builders) and differ only in: + * + * TileScheduler type: + * - DynamicPersistentScheduler (CLC-based) + * - StaticPersistentScheduler + * + * Use --scheduler=dynamic (default) or --scheduler=static to select the scheduler. + */ + +#include + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +/// Panic wrapper for CUDA Driver API errors (CUresult). +#define CUDA_DRIVER_CHECK(status) \ + do { \ + CUresult error = status; \ + if (error != CUDA_SUCCESS) { \ + const char *error_string; \ + cuGetErrorString(error, &error_string); \ + std::cerr << "Got CUDA driver error: " << error \ + << " (" << error_string << ")" \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +/// Alias for CUDA Runtime API error checks (cudaError_t). +/// CUDA_CHECK is defined in helper.h; this alias makes the API boundary explicit. +#define CUDA_RUNTIME_CHECK CUDA_CHECK + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Common GEMM kernel configuration (shared by both schedulers) +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace gemm_common_config { + + using namespace cute; + + using ElementA = half_t; + using LayoutA = cutlass::layout::RowMajor; + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = half_t; + using LayoutB = cutlass::layout::ColumnMajor; + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using MmaTileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized2Sm; + using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecialized2SmSm100; + +} // namespace gemm_common_config + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Dynamic Persistent CLC kernel +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace dynamic_kernel_config { + + using namespace cute; + using namespace gemm_common_config; + + using TileScheduler = cutlass::gemm::DynamicPersistentScheduler; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueScheduleType + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler>; + +} // namespace dynamic_kernel_config + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Static Persistent kernel +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace static_kernel_config { + + using namespace cute; + using namespace gemm_common_config; + + using TileScheduler = cutlass::gemm::StaticPersistentScheduler; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueScheduleType + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileScheduler>; + +} // namespace static_kernel_config + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Gemm device adapter types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using DynamicGemm = cutlass::gemm::device::GemmUniversalAdapter; +using StaticGemm = cutlass::gemm::device::GemmUniversalAdapter; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Common type aliases and global data (identical across both kernels) +///////////////////////////////////////////////////////////////////////////////////////////////// + +using ElementA = gemm_common_config::ElementA; +using ElementB = gemm_common_config::ElementB; +using ElementC = gemm_common_config::ElementC; +using LayoutA = gemm_common_config::LayoutA; +using LayoutB = gemm_common_config::LayoutB; +using LayoutC = gemm_common_config::LayoutC; +using ElementAccumulator = gemm_common_config::ElementAccumulator; + +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, ElementAccumulator>; + +using StrideA = typename DynamicGemm::GemmKernel::StrideA; +using StrideB = typename DynamicGemm::GemmKernel::StrideB; +using StrideC = typename DynamicGemm::GemmKernel::StrideC; +using StrideD = typename DynamicGemm::GemmKernel::StrideD; + +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed = 0; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Options { + + bool help; + bool use_cuda_graph; + int m, n, k; + float alpha, beta; + int iterations; + int swizzle; + int max_num_sm; + std::string raster_order; + std::string scheduler; + + Options(): + help(false), + use_cuda_graph(false), + m(8192), n(8192), k(8192), + alpha(1.f), beta(0.f), + iterations(30), + swizzle(0), + max_num_sm(0), + raster_order("heuristic"), + scheduler("dynamic") + { } + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("swizzle", swizzle); + cmd.get_cmd_line_argument("max_num_sm", max_num_sm); + cmd.get_cmd_line_argument("raster_order", raster_order, std::string("heuristic")); + cmd.get_cmd_line_argument("scheduler", scheduler, std::string("dynamic")); + use_cuda_graph = cmd.check_cmd_line_flag("use_cuda_graph"); + } + + using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; + RasterOrderOptions get_raster_order() const { + if (raster_order == "along_m") { + return RasterOrderOptions::AlongM; + } else if (raster_order == "along_n") { + return RasterOrderOptions::AlongN; + } else { + return RasterOrderOptions::Heuristic; + } + } + + std::ostream & print_usage(std::ostream &out) const { + + out << "95_blackwell_gemm_green_context\n\n" + << " Blackwell FP16 GEMM with Green Context support.\n" + << " Supports both Dynamic Persistent CLC and Static Persistent schedulers.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --scheduler= Tile scheduler: 'dynamic' (default) or 'static'\n" + << " dynamic: DynamicPersistentScheduler (CLC-based)\n" + << " static: StaticPersistentScheduler\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --swizzle= Cluster rasterization swizzle\n" + << " --raster_order= Raster order: 'heuristic' (default), 'along_m', or 'along_n'\n" + << " --max_num_sm= Max number of SMs for green context partition (0 = use all SMs, no green context)\n" + << " --use_cuda_graph If specified, use CUDA graph capture/replay for profiling iterations\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << " # Dynamic scheduler, all SMs (no green context)\n" + << " $ 95_blackwell_gemm_green_context --scheduler=dynamic --m=8192 --n=8192 --k=8192\n\n" + << " # Static scheduler, 120-SM green context partition\n" + << " $ 95_blackwell_gemm_green_context --scheduler=static --m=8192 --n=8192 --k=8192 --max_num_sm=120\n\n"; + + return out; + } + + double gflops(double runtime_s) const + { + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +void initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); +} + +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options. +/// +/// The static persistent scheduler needs hw_info (sm_count, max_active_clusters) to compute +/// the launch grid size. When overlapping with another kernel via green context partitioning, +/// the green context stream MUST be passed so that cudaOccupancyMaxActiveClusters returns the +/// partition-scoped value. The dynamic persistent scheduler does not use hw_info. +/// +/// We always set hw_info here for uniformity; it is harmless for the dynamic scheduler. +template +typename Gemm::Arguments args_from_options(const Options &options, cudaStream_t stream) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, 1}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + + arguments.scheduler.max_swizzle_size = options.swizzle; + arguments.scheduler.raster_order = options.get_raster_order(); + + // ===================================================================================== + // IMPORTANT: Static persistent scheduler needs hw_info (sm_count, max_active_clusters) + // to compute the correct launch grid size. + // + // When the static persistent kernel overlaps with other kernel (e.g. via + // green context partitioning), you MUST pass the green context stream to + // make_kernel_hardware_info() so that cudaOccupancyMaxActiveClusters queries the max + // active clusters for that SM partition -- NOT the full device. + // + // If the stream is not passed, the query returns the full-device max active clusters, + // which leads to an oversized persistent grid that exceeds the partition's capacity, + // causing performance issues (because static scheduler stride to next work with launch grid size) + // ===================================================================================== + arguments.hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info( + 0 /* device_id */, 0 /* sm_count: auto-query */, 0 /* max_active_clusters: auto-query */, stream); + + return arguments; +} + +template +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutC::packed({options.m, options.n})); + + DeviceGemmReference gemm_reference; + + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + CUDA_RUNTIME_CHECK(cudaDeviceSynchronize()); + + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation. +/// Stream is REQUIRED -- pass cudaStreamDefault (0) for the full device, or a green context +/// stream for a partitioned launch. +template +int run(const Options &options, cudaStream_t stream) +{ + initialize(options); + + Gemm gemm; + + auto arguments = args_from_options(options, stream); + + dim3 grid = Gemm::get_grid_shape(arguments); + std::cout << " hw_info: sm_count=" << arguments.hw_info.sm_count + << ", max_active_clusters=" << arguments.hw_info.max_active_clusters << std::endl; + std::cout << " Launch grid: (" << grid.x << ", " << grid.y << ", " << grid.z << ")" << std::endl; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm.can_implement(arguments)); + + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + CUTLASS_CHECK(gemm.run(stream)); + + CUDA_RUNTIME_CHECK(cudaStreamSynchronize(stream)); + + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + if (options.iterations > 0) + { + GpuTimer timer; + + if (options.use_cuda_graph) { + // cudaStreamBeginCapture cannot capture on the legacy default stream (stream 0). + cudaStream_t capture_stream = (stream == cudaStreamDefault) ? cudaStreamPerThread : stream; + + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + cudaGraph_t graph; + cudaGraphExec_t graph_exec; + CUDA_RUNTIME_CHECK(cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeGlobal)); + CUTLASS_CHECK(gemm.run(capture_stream)); + CUDA_RUNTIME_CHECK(cudaStreamEndCapture(capture_stream, &graph)); + CUDA_RUNTIME_CHECK(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0)); + + std::cout << " Using CUDA Graph for " << options.iterations << " iterations" << std::endl; + + timer.start(stream); + for (int iter = 0; iter < options.iterations; ++iter) { + CUDA_RUNTIME_CHECK(cudaGraphLaunch(graph_exec, stream)); + } + timer.stop(); + + CUDA_RUNTIME_CHECK(cudaGraphExecDestroy(graph_exec)); + CUDA_RUNTIME_CHECK(cudaGraphDestroy(graph)); + } + else { + timer.start(stream); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run(stream)); + } + timer.stop(); + } + + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +/// Dispatches to the correct Gemm type and handles green context setup. +template +int dispatch(const Options &options, int current_device_id) +{ + if (options.max_num_sm > 0) { + // + // Green Context path: partition SMs and launch kernel on primary partition + // + + CUDA_DRIVER_CHECK(cuInit(0)); + + CUdevice cu_device; + CUcontext primary_context; + CUDA_DRIVER_CHECK(cuDeviceGet(&cu_device, current_device_id)); + CUDA_DRIVER_CHECK(cuDevicePrimaryCtxRetain(&primary_context, cu_device)); + CUDA_DRIVER_CHECK(cuCtxSetCurrent(primary_context)); + + CUdevResource device_resource; + CUDA_DRIVER_CHECK(cuDeviceGetDevResource(cu_device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); + std::cout << " Device SM count: " << device_resource.sm.smCount << std::endl; + + if (options.max_num_sm >= static_cast(device_resource.sm.smCount)) { + std::cout << " --max_num_sm (" << options.max_num_sm + << ") >= device SM count (" << device_resource.sm.smCount + << "), no green context split, using a dedicated stream with all SMs" << std::endl; + + cudaStream_t stream; + CUDA_RUNTIME_CHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + int rc = run(options, stream); + CUDA_RUNTIME_CHECK(cudaStreamDestroy(stream)); + CUDA_DRIVER_CHECK(cuDevicePrimaryCtxRelease(cu_device)); + return rc; + } + + CUdevResource primary_partition_resource; + CUdevResource remaining_partition_resource; + unsigned int num_groups = 1; +#if CUDA_VERSION >= 13000 + unsigned int sm_alignment = device_resource.sm.smCoscheduledAlignment; + unsigned int aligned_sm = (static_cast(options.max_num_sm) / sm_alignment) * sm_alignment; + if (aligned_sm == 0) { + aligned_sm = sm_alignment; + } + std::cout << " SM coscheduled alignment: " << sm_alignment << std::endl; + std::cout << " Requested --max_num_sm: " << options.max_num_sm + << ", aligned (round down): " << aligned_sm << std::endl; + CUDA_DRIVER_CHECK(cuDevSmResourceSplitByCount( + &primary_partition_resource, + &num_groups, + &device_resource, + &remaining_partition_resource, + 0, + aligned_sm)); +#else + std::cout << " Requested --max_num_sm: " << options.max_num_sm << std::endl; + CUDA_DRIVER_CHECK(cuDevSmResourceSplitByCount( + &primary_partition_resource, + &num_groups, + &device_resource, + &remaining_partition_resource, + 0, + static_cast(options.max_num_sm))); +#endif + + std::cout << " Primary partition SM count (for GEMM): " << primary_partition_resource.sm.smCount << std::endl; + std::cout << " Remaining partition SM count: " << remaining_partition_resource.sm.smCount << std::endl; + + CUdevResourceDesc primary_partition_desc; + CUdevResourceDesc remaining_partition_desc; + CUDA_DRIVER_CHECK(cuDevResourceGenerateDesc(&primary_partition_desc, &primary_partition_resource, 1)); + CUDA_DRIVER_CHECK(cuDevResourceGenerateDesc(&remaining_partition_desc, &remaining_partition_resource, 1)); + + CUgreenCtx primary_partition_green_ctx; + CUgreenCtx remaining_partition_green_ctx; + CUDA_DRIVER_CHECK(cuGreenCtxCreate(&primary_partition_green_ctx, primary_partition_desc, cu_device, CU_GREEN_CTX_DEFAULT_STREAM)); + CUDA_DRIVER_CHECK(cuGreenCtxCreate(&remaining_partition_green_ctx, remaining_partition_desc, cu_device, CU_GREEN_CTX_DEFAULT_STREAM)); + + // Remaining partition stream is not used in this example but instantiated to show how. + CUstream primary_partition_cu_stream; + CUstream remaining_partition_cu_stream; + CUDA_DRIVER_CHECK(cuGreenCtxStreamCreate(&primary_partition_cu_stream, primary_partition_green_ctx, CU_STREAM_NON_BLOCKING, 0)); + CUDA_DRIVER_CHECK(cuGreenCtxStreamCreate(&remaining_partition_cu_stream, remaining_partition_green_ctx, CU_STREAM_NON_BLOCKING, 0)); + + cudaStream_t primary_partition_stream = static_cast(primary_partition_cu_stream); + int rc = run(options, primary_partition_stream); + + CUDA_DRIVER_CHECK(cuStreamDestroy(primary_partition_cu_stream)); + CUDA_DRIVER_CHECK(cuStreamDestroy(remaining_partition_cu_stream)); + CUDA_DRIVER_CHECK(cuGreenCtxDestroy(primary_partition_green_ctx)); + CUDA_DRIVER_CHECK(cuGreenCtxDestroy(remaining_partition_green_ctx)); + CUDA_DRIVER_CHECK(cuDevicePrimaryCtxRelease(cu_device)); + + return rc; + } + else { + return run(options, cudaStreamDefault); + } +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 9)) { + std::cerr << "This example requires CUDA 12.9 or newer." << std::endl; + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_RUNTIME_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_RUNTIME_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + + if (props.major != 10 || props.minor != 0) { + std::cerr << "This example requires a GPU with compute capability 100a." << std::endl; + return 0; + } + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + if (options.scheduler == "static") { + std::cout << "Using StaticPersistentScheduler" << std::endl; + return dispatch(options, current_device_id); + } + else if (options.scheduler == "dynamic") { + std::cout << "Using DynamicPersistentScheduler" << std::endl; + return dispatch(options, current_device_id); + } + else { + std::cerr << "Unknown scheduler: '" << options.scheduler + << "'. Use --scheduler=dynamic or --scheduler=static." << std::endl; + return 1; + } +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/95_blackwell_gemm_green_context/CMakeLists.txt b/examples/95_blackwell_gemm_green_context/CMakeLists.txt new file mode 100644 index 000000000..9f9a2173a --- /dev/null +++ b/examples/95_blackwell_gemm_green_context/CMakeLists.txt @@ -0,0 +1,54 @@ + +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +set(TEST_DYNAMIC_MAX_NUM_SM_64 --scheduler=dynamic --max_num_sm=64) +set(TEST_DYNAMIC_MAX_NUM_SM_120 --scheduler=dynamic --max_num_sm=120) +set(TEST_DYNAMIC_MAX_NUM_SM_148 --scheduler=dynamic --max_num_sm=148) +set(TEST_DYNAMIC_MAX_NUM_SM_160 --scheduler=dynamic --max_num_sm=160) +set(TEST_STATIC_MAX_NUM_SM_64 --scheduler=static --max_num_sm=64) +set(TEST_STATIC_MAX_NUM_SM_120 --scheduler=static --max_num_sm=120) +set(TEST_STATIC_MAX_NUM_SM_148 --scheduler=static --max_num_sm=148) +set(TEST_STATIC_MAX_NUM_SM_160 --scheduler=static --max_num_sm=160) + +if(CUTLASS_NVCC_ARCHS MATCHES "100a|100f|101a|101f|103a|103f") +cutlass_example_add_executable( + 95_blackwell_gemm_green_context + 95_blackwell_gemm_green_context.cu + TEST_COMMAND_OPTIONS + TEST_DYNAMIC_MAX_NUM_SM_64 + TEST_DYNAMIC_MAX_NUM_SM_120 + TEST_DYNAMIC_MAX_NUM_SM_148 + TEST_DYNAMIC_MAX_NUM_SM_160 + TEST_STATIC_MAX_NUM_SM_64 + TEST_STATIC_MAX_NUM_SM_120 + TEST_STATIC_MAX_NUM_SM_148 + TEST_STATIC_MAX_NUM_SM_160 +) +endif() diff --git a/examples/95_blackwell_gemm_green_context/README.md b/examples/95_blackwell_gemm_green_context/README.md new file mode 100644 index 000000000..f7138c4b7 --- /dev/null +++ b/examples/95_blackwell_gemm_green_context/README.md @@ -0,0 +1,114 @@ +# Example 95: Blackwell GEMM with Green Context + +[Green Context](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html) is a lightweight method to partition GPU SM resources. + +This example demonstrates a Blackwell FP16 GEMM with green context support using two tile scheduler variants, selectable at runtime via `--scheduler`: + +1. **Dynamic Persistent CLC** (`--scheduler=dynamic`, default): Uses `DynamicPersistentScheduler` (Cluster Launch Control based). No modification needed for green context -- simply launch kernel onto partition stream. +2. **Static Persistent** (`--scheduler=static`): Uses `StaticPersistentScheduler`. For green context, use partition stream to query max active cluster; modify launch grid based on partition stream max active cluster. + +For SM90 Hopper dynamic / static kernels, modifying the launch grid is required. + +## Build + +From the CUTLASS build directory: + +```shell +# Configure (only examples, SM100a) +cmake \ + -DCUTLASS_NVCC_ARCHS=100a \ + -DCUTLASS_ENABLE_EXAMPLES=ON \ + -DCUTLASS_ENABLE_TESTS=OFF \ + -DCUTLASS_ENABLE_LIBRARY=OFF \ + -DCUTLASS_ENABLE_PROFILER=OFF \ + -DCMAKE_BUILD_TYPE=Release + +# Build only this example +make 95_blackwell_gemm_green_context -j$(nproc) +``` + +## Run + +### Dynamic Persistent CLC scheduler + +#### Without green context (all SMs) + +```shell +./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=dynamic --m=8192 --n=8192 --k=8192 --iterations=30 +``` + +#### With green context (partitioned SMs) + +Use `--max_num_sm` to specify the number of SMs for the primary partition (GEMM workload). +The remaining SMs are assigned to the remaining partition. + +```shell +./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=dynamic --m=8192 --n=8192 --k=8192 --max_num_sm=120 --iterations=30 --raster_order=along_m + +./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=dynamic --m=8192 --n=8192 --k=8192 --max_num_sm=120 --iterations=30 --raster_order=along_n +``` + +### Static Persistent scheduler + +#### Without green context (all SMs) + +```shell +./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=static --m=8192 --n=8192 --k=8192 --iterations=30 +``` + +#### With green context (partitioned SMs) + +```shell +./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=static --m=8192 --n=8192 --k=8192 --max_num_sm=120 --iterations=30 --raster_order=along_m + +./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=static --m=8192 --n=8192 --k=8192 --max_num_sm=120 --iterations=30 --raster_order=along_n +``` + + +## Nsight Systems Profiling + +Use `nsys profile` to capture the kernel execution under green context partitioning. +In the Nsys UI, you should see the green context section on the left panel, and the GEMM kernel +launched onto that green context partition. + +### Profile Dynamic Persistent CLC scheduler + +#### Without green context + +```shell +nsys profile -o gemm_dynamic_all_sm \ + ./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=dynamic --m=8192 --n=8192 --k=8192 --iterations=30 +``` + +#### With green context (120 SMs for GEMM) + +```shell +nsys profile -o gemm_dynamic_green_ctx_120sm \ + ./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=dynamic --m=8192 --n=8192 --k=8192 --max_num_sm=120 --iterations=30 +``` + +### Profile Static Persistent scheduler + +#### Without green context + +```shell +nsys profile -o gemm_static_all_sm \ + ./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=static --m=8192 --n=8192 --k=8192 --iterations=30 +``` + +#### With green context (120 SMs for GEMM) + +```shell +nsys profile -o gemm_static_green_ctx_120sm \ + ./examples/95_blackwell_gemm_green_context/95_blackwell_gemm_green_context \ + --scheduler=static --m=8192 --n=8192 --k=8192 --max_num_sm=120 --iterations=30 +``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index f27834d92..d7670f642 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -172,6 +172,7 @@ foreach(EXAMPLE 92_blackwell_moe_gemm 93_blackwell_low_latency_gqa 94_ada_fp8_blockwise + 95_blackwell_gemm_green_context 111_hopper_ssd 112_blackwell_ssd ) diff --git a/examples/python/CuTeDSL/ampere/dynamic_smem_size.py b/examples/python/CuTeDSL/ampere/dynamic_smem_size.py index 9bbc0b46e..5cf5b80aa 100644 --- a/examples/python/CuTeDSL/ampere/dynamic_smem_size.py +++ b/examples/python/CuTeDSL/ampere/dynamic_smem_size.py @@ -26,9 +26,11 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + import cutlass.cute as cute import cutlass + """ Example of automatic shared memory size computation for configuring kernel launch @@ -51,11 +53,15 @@ class SharedData: @cute.kernel -def kernel(): +def kernel_static(): """ Example kernel that allocates shared memory. The total allocation will be automatically calculated when smem=None. """ + tidx, _, _ = cute.arch.block_idx() + if tidx == 0: + cute.printf("Running kernel_static") + allocator = cutlass.utils.SmemAllocator() # Allocate various types of shared memory @@ -68,6 +74,8 @@ def kernel(): byte_alignment=16, swizzle=None, ) + + cute.printf("Kernel launch smem size: {}", cute.arch.dynamic_smem_size()) return @@ -79,7 +87,7 @@ def kernel_no_smem(): """ tidx, _, _ = cute.arch.block_idx() if tidx == 0: - cute.printf("Hello world") + cute.printf("Running kernel_no_smem") return @@ -89,26 +97,49 @@ if __name__ == "__main__": print("Launching kernel with auto smem size. (launch config `smem=None`)") - # Compile the example + # Compile the static example @cute.jit - def launch_kernel1(): - k = kernel() - k.launch( + def launch_kernelno_smem(): + kernel_no_smem().launch( grid=(1, 1, 1), block=(1, 1, 1), ) - print(f"Kernel recorded internal smem usage: {k.smem_usage()}") + + # -------- + print(f" > Run {kernel_no_smem.__name__}") + func = cute.compile(launch_kernelno_smem) + func() + cutlass.cuda.stream_sync(cutlass.cuda.default_stream()) @cute.jit - def launch_kernel2(): - k = kernel_no_smem() - k.launch( + def launch_kernel_static(): + kernel_static().launch( grid=(1, 1, 1), block=(1, 1, 1), + # smem=None + # auto infer launch kernel static smem usage ) - print(f"Kernel recorded internal smem usage: {k.smem_usage()}") - cute.compile(launch_kernel1) - cute.compile(launch_kernel2) + # -------- + print(f" > Run {kernel_static.__name__} with sufficient smem") + func = cute.compile(launch_kernel_static) + func() + cutlass.cuda.stream_sync(cutlass.cuda.default_stream()) + + @cute.jit + def launch_kernel_static_insufficient(): + kernel_static().launch( + grid=(1, 1, 1), + block=(1, 1, 1), + # launch kernel with static smem usage exceeds cfg + # show warning + smem=16, + ) + + # -------- + print(f" > Run {kernel_static.__name__} with insufficient smem, show warning:") + func = cute.compile(launch_kernel_static_insufficient) + func() + cutlass.cuda.stream_sync(cutlass.cuda.default_stream()) print("PASS") diff --git a/examples/python/CuTeDSL/ampere/hstu_attention.py b/examples/python/CuTeDSL/ampere/hstu_attention.py index 3b537792a..fbe6f7af4 100644 --- a/examples/python/CuTeDSL/ampere/hstu_attention.py +++ b/examples/python/CuTeDSL/ampere/hstu_attention.py @@ -265,7 +265,6 @@ class HSTUAttentionForwardAmpere(object): ).launch( grid=grid_dim, block=[self._num_threads, 1, 1], - smem=SharedStorage.size_in_bytes(), stream=stream, ) diff --git a/examples/python/CuTeDSL/ampere/smem_allocator.py b/examples/python/CuTeDSL/ampere/smem_allocator.py index ea004ead0..ffb4e19cb 100644 --- a/examples/python/CuTeDSL/ampere/smem_allocator.py +++ b/examples/python/CuTeDSL/ampere/smem_allocator.py @@ -129,8 +129,8 @@ def kernel( # ptr> # ptr> print(struct_in_smem.a.data_ptr()) - print(struct_in_smem.b) - print(struct_in_smem.c.real) + print(struct_in_smem.b.ptr) + print(struct_in_smem.c.real.ptr) # ptr> print(section_in_smem) # ptr> @@ -138,6 +138,17 @@ def kernel( # tensor> o (16,4):(1,16)> print(tensor_in_smem) + # assign struct member array element + cute.printf("struct_in_smem.a[0] = {}", struct_in_smem.a[0]) + struct_in_smem.a[0] = 2 + cute.printf("struct_in_smem.a[0] = {}", struct_in_smem.a[0]) + + # assign struct member scalar + cute.printf("struct_in_smem.b.ptr = {}", struct_in_smem.b.ptr) + cute.printf("struct_in_smem.b: value = {}", struct_in_smem.b.ptr.load()) + struct_in_smem.b = 16 + cute.printf("struct_in_smem.b: value = {}", struct_in_smem.b.ptr.load()) + # fill MemRange tensor in struct and copy to dst a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4))) a_tensor.fill(const_a) @@ -169,7 +180,9 @@ def host( ): # Note: Shared Memory size is automatically calculated now kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch( - grid=(1, 1, 1), block=(1, 1, 1) + grid=(1, 1, 1), + block=(1, 1, 1), + # Automatically calculate the launch kernel shared memory usage when `smem=None` ) diff --git a/examples/python/CuTeDSL/ampere/tensorop_gemm.py b/examples/python/CuTeDSL/ampere/tensorop_gemm.py index ad3fe0f0b..0c22c98f2 100644 --- a/examples/python/CuTeDSL/ampere/tensorop_gemm.py +++ b/examples/python/CuTeDSL/ampere/tensorop_gemm.py @@ -175,15 +175,6 @@ class TensorOpGemm: (self.cta_tiler[0], self.cta_tiler[1]), ) - # Shared memory allocated for operations with A, B will be - # overwritten for operations on C. This is to improve performance - # by reducing the size of shared memory requested by each block - smem_size = max( - cute.size_in_bytes(mC.element_type, sC_layout), - cute.size_in_bytes(mA.element_type, sA_layout) - + cute.size_in_bytes(mB.element_type, sB_layout), - ) - # /////////////////////////////////////////////////////////////////////////////// # Tiled copy: # The majorness of tA/tB/tC follows the majorness of gA/gB/gC, @@ -282,7 +273,6 @@ class TensorOpGemm: ).launch( grid=rasterization_remap_grid_dim, block=[self.num_threads, 1, 1], - smem=smem_size, ) @cute.kernel @@ -382,14 +372,36 @@ class TensorOpGemm: # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) # /////////////////////////////////////////////////////////////////////////////// + @cute.struct + class SharedStorageAB: + a: cute.struct.Align[ + cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], + 16, + ] + b: cute.struct.Align[ + cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], + 16, + ] + + @cute.struct + class SharedStorageC: + c: cute.struct.Align[ + cute.struct.MemRange[mC.element_type, cute.cosize(sC_layout)], + 16, + ] + # Shared memory buffer smem = cutlass.utils.SmemAllocator() - - sA = smem.allocate_tensor(mA.element_type, sA_layout, 16) - sB = smem.allocate_tensor(mB.element_type, sB_layout, 16) - sC = cute.make_tensor( - cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout + # Shared memory allocated for operations with A, B will be + # overwritten for operations on C. This is to improve performance + # by reducing the size of shared memory requested by each block + storage = smem.allocate( + max(SharedStorageAB.size_in_bytes(), SharedStorageC.size_in_bytes()), + byte_alignment=16, ) + sA = SharedStorageAB(storage).a.get_tensor(sA_layout) + sB = SharedStorageAB(storage).b.get_tensor(sB_layout) + sC = SharedStorageC(storage).c.get_tensor(sC_layout) thr_copy_A = tiled_copy_A.get_slice(tidx) thr_copy_B = tiled_copy_B.get_slice(tidx) diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py b/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py index a54e0fb03..f5dc77afc 100644 --- a/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py +++ b/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py @@ -549,7 +549,7 @@ class BlockwiseGemmKernel: cutlass.Int64, self.num_tile_stage * 2 ] epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ @@ -614,7 +614,6 @@ class BlockwiseGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @@ -682,9 +681,6 @@ class BlockwiseGemmKernel: smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf - # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 @@ -771,11 +767,11 @@ class BlockwiseGemmKernel: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_alloc_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py b/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py index bc99a15c8..5e7b1f5a2 100644 --- a/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py +++ b/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py @@ -568,7 +568,7 @@ class BlockwiseContiguousGroupedGemmKernel: cutlass.Int64, self.num_tile_stage * 2 ] epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ @@ -634,7 +634,6 @@ class BlockwiseContiguousGroupedGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @@ -703,9 +702,6 @@ class BlockwiseContiguousGroupedGemmKernel: smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf - # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 @@ -792,11 +788,11 @@ class BlockwiseContiguousGroupedGemmKernel: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_alloc_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py b/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py index 316a5e9aa..39f8c4a8b 100644 --- a/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py +++ b/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py @@ -567,7 +567,7 @@ class BlockwiseMaskedGroupedGemmKernel: cutlass.Int64, self.num_tile_stage * 2 ] epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ @@ -633,7 +633,6 @@ class BlockwiseMaskedGroupedGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @@ -702,9 +701,6 @@ class BlockwiseMaskedGroupedGemmKernel: smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf - # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 @@ -791,11 +787,11 @@ class BlockwiseMaskedGroupedGemmKernel: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_alloc_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py index b74d35dd0..3b9bbe486 100644 --- a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py @@ -615,7 +615,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ @@ -794,11 +794,11 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_alloc_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_amax.py b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_amax.py index f423fcced..c4c35f44a 100644 --- a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_amax.py +++ b/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_amax.py @@ -551,7 +551,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel: ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ @@ -737,11 +737,11 @@ class Sm100BlockScaledPersistentDenseGemmKernel: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_alloc_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm.py b/examples/python/CuTeDSL/blackwell/dense_gemm.py index 94495fb1a..6c6144bc8 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm.py @@ -517,7 +517,7 @@ class DenseGemmKernel: acc_full_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.num_acc_stage * 2 ] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 smem = utils.SmemAllocator() @@ -564,10 +564,10 @@ class DenseGemmKernel: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py b/examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py index d88813250..9a81ba370 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py @@ -493,7 +493,7 @@ class SM100PersistentDenseGemmAlphaBetaKernel: acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] c_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage] c_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sD: cute.struct.Align[ @@ -674,11 +674,11 @@ class SM100PersistentDenseGemmAlphaBetaKernel: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.epilog_warp_ids[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py index 731624b24..10d62d239 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py @@ -593,7 +593,7 @@ class PersistentDenseGemmKernel: acc_full_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.num_acc_stage * 2 ] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 smem = utils.SmemAllocator() @@ -644,11 +644,11 @@ class PersistentDenseGemmKernel: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py index e16eba200..d4ce020ed 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py @@ -614,7 +614,7 @@ class PersistentDenseGemmKernel: acc_full_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.num_acc_stage * 2 ] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2] clc_response: cute.struct.MemRange[cutlass.Int32, 4] @@ -686,11 +686,11 @@ class PersistentDenseGemmKernel: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py index 3622b2114..46e16770f 100644 --- a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py +++ b/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py @@ -514,7 +514,7 @@ class DenseGemmKernel: acc_full_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.num_acc_stage * 2 ] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 smem = utils.SmemAllocator() @@ -562,10 +562,10 @@ class DenseGemmKernel: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/epilogue/common_dense_gemm_efc.py b/examples/python/CuTeDSL/blackwell/epilogue/common_dense_gemm_efc.py index 0a4fa2d7b..c4c72e1bb 100644 --- a/examples/python/CuTeDSL/blackwell/epilogue/common_dense_gemm_efc.py +++ b/examples/python/CuTeDSL/blackwell/epilogue/common_dense_gemm_efc.py @@ -573,7 +573,7 @@ class DenseGemmEFC: # Barriers used by the supplemental load tensor pipeline. c_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage] c_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_c_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (MMA, MMA_M, MMA_K, STAGE) sA: cute.struct.Align[ @@ -651,11 +651,11 @@ class DenseGemmEFC: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=self.use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/fmha.py b/examples/python/CuTeDSL/blackwell/fmha.py index bbeb6a610..5d4ddd7b4 100644 --- a/examples/python/CuTeDSL/blackwell/fmha.py +++ b/examples/python/CuTeDSL/blackwell/fmha.py @@ -998,7 +998,7 @@ class BlackwellFusedMultiHeadAttentionForward: # Alloc tmem buffer tmem_alloc_cols = Int32(self.tmem_alloc_cols) - cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf.ptr) self.tmem_alloc_barrier.arrive_and_wait() tile_sched = fmha_utils.create_fmha_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() @@ -1260,7 +1260,7 @@ class BlackwellFusedMultiHeadAttentionForward: tmem_ptr = cute.arch.retrieve_tmem_ptr( Float32, alignment=16, - ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf.ptr, ) cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) diff --git a/examples/python/CuTeDSL/blackwell/fmha_bwd.py b/examples/python/CuTeDSL/blackwell/fmha_bwd.py index 5725d5e8a..9d415423a 100644 --- a/examples/python/CuTeDSL/blackwell/fmha_bwd.py +++ b/examples/python/CuTeDSL/blackwell/fmha_bwd.py @@ -708,7 +708,6 @@ class BlackwellFusedMultiHeadAttentionBackward: grid=bwd_grid, block=[self.threads_per_cta, 1, 1], cluster=[1, 1, 1], - smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @@ -913,7 +912,7 @@ class BlackwellFusedMultiHeadAttentionBackward: ) sLSE = storage.sLSE.get_tensor(LSE_smem_layout) sSum_OdO = storage.sSum_OdO.get_tensor(sum_OdO_smem_layout) - tmem_holding_buf = storage.tmem_holding_buf + tmem_holding_buf = storage.tmem_holding_buf.ptr sQT_ptr = cute.recast_ptr(sQ.iterator, QT_smem_layout_staged.inner) sQT = cute.make_tensor(sQT_ptr, QT_smem_layout_staged.outer) diff --git a/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py b/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py index 0f2dd5fad..1f6d65860 100644 --- a/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py +++ b/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py @@ -567,7 +567,7 @@ class Sm100GroupedBlockScaledGemmKernel: ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ @@ -641,7 +641,6 @@ class Sm100GroupedBlockScaledGemmKernel: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), - smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @@ -737,8 +736,8 @@ class Sm100GroupedBlockScaledGemmKernel: + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 ) - tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr - tmem_holding_buf = storage.tmem_holding_buf + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar.ptr + tmem_holding_buf_ptr = storage.tmem_holding_buf.ptr # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) @@ -1249,7 +1248,7 @@ class Sm100GroupedBlockScaledGemmKernel: acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, - ptr_to_buffer_holding_addr=tmem_holding_buf, + ptr_to_buffer_holding_addr=tmem_holding_buf_ptr, ) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) @@ -1446,7 +1445,7 @@ class Sm100GroupedBlockScaledGemmKernel: if warp_idx == self.epilog_warp_id[0]: cute.arch.alloc_tmem( self.num_tmem_alloc_cols, - tmem_holding_buf, + tmem_holding_buf_ptr, is_two_cta=use_2cta_instrs, ) @@ -1461,7 +1460,7 @@ class Sm100GroupedBlockScaledGemmKernel: acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, - ptr_to_buffer_holding_addr=tmem_holding_buf, + ptr_to_buffer_holding_addr=tmem_holding_buf_ptr, ) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) diff --git a/examples/python/CuTeDSL/blackwell/grouped_gemm.py b/examples/python/CuTeDSL/blackwell/grouped_gemm.py index 50c208f20..e4c28f143 100644 --- a/examples/python/CuTeDSL/blackwell/grouped_gemm.py +++ b/examples/python/CuTeDSL/blackwell/grouped_gemm.py @@ -425,7 +425,7 @@ class GroupedGemmKernel: ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ @@ -590,11 +590,11 @@ class GroupedGemmKernel: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_alloc_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py index 630de111f..bd709f837 100644 --- a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py +++ b/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py @@ -820,7 +820,7 @@ class SSDKernel: num_threads=self.threads_per_cta, ) tmem = utils.TmemAllocator( - smem_storage.tmem_holding_buf, + smem_storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.epilog_warp_id[0], ) diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_decode.py b/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_decode.py index 0fd5d3652..579b3cd0e 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_decode.py +++ b/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_decode.py @@ -603,8 +603,6 @@ class MixedInputFusedMultiHeadAttentionDecode: p_pipeline_ptr = smem.allocate_array(Int64, self.sp_stages * 2) o_pipeline_ptr = smem.allocate_array(Int64, self.o_stages * 2) - assert smem._allocated_bytes <= self.mbarrier_reserved_bytes - # Declare named barriers softmax_nbar_id = 1 mma_kq_nbar_id = 2 diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d256.py b/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d256.py index aaaff2ac4..07a2a73da 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d256.py +++ b/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d256.py @@ -403,7 +403,7 @@ class MixedInputFusedMultiHeadAttentionPrefillD256: s_corr_mbar_ptr: cute.struct.MemRange[Int64, self.qk_acc_stage * 2] sum_mbar_ptr: cute.struct.MemRange[Int64, 2] mma_o_mbar_ptr: cute.struct.MemRange[Int64, self.pv_acc_stage * 2] - tmem_dealloc_mbar_ptr: Int64 + tmem_dealloc_mbar: Int64 tmem_holding_buf: Int32 self.shared_storage = SharedStorage @@ -654,11 +654,11 @@ class MixedInputFusedMultiHeadAttentionPrefillD256: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.correction_warp_ids[0], is_two_cta=True, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True) diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d512.py b/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d512.py index 874ffc495..89713fef5 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d512.py +++ b/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d512.py @@ -390,7 +390,7 @@ class MixedInputFusedMultiHeadAttentionPrefillD512: p_mma_mbar_ptr: cute.struct.MemRange[Int64, self.qk_acc_stage * 2] mma_o_mbar_ptr: cute.struct.MemRange[Int64, self.pv_acc_stage * 2] swap_mbar_ptr: cute.struct.MemRange[Int64, self.swap_stage * 2] - tmem_dealloc_mbar_ptr: Int64 + tmem_dealloc_mbar: Int64 tmem_holding_buf: Int32 self.shared_storage = SharedStorage @@ -627,11 +627,11 @@ class MixedInputFusedMultiHeadAttentionPrefillD512: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.softmax_warp_ids[0], is_two_cta=True, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init pipeline_init_arrive(cluster_shape_mn=cluster_layout_vmnk, is_relaxed=True) diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm.py b/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm.py index a81d1e393..acab8a761 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm.py +++ b/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm.py @@ -624,7 +624,7 @@ class GroupedMixedInputGemmKernel: tile_info_empty_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.num_tile_info_stage ] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 self.shared_storage = SharedStorage @@ -824,11 +824,11 @@ class GroupedMixedInputGemmKernel: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_ptr_sync_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py b/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py index 472816b6e..43322e283 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py +++ b/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py @@ -518,7 +518,7 @@ class GroupedMixedInputGemmAccScaleKernel: tile_info_empty_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.num_tile_info_stage ] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 self.shared_storage = SharedStorage @@ -708,11 +708,11 @@ class GroupedMixedInputGemmAccScaleKernel: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_ptr_sync_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py b/examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py index 2517e69ae..ac0dd51f4 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py +++ b/examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py @@ -618,7 +618,7 @@ class MixedInputGemmKernel: ] acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # Tensor buffers # (EPI_TILE_M, EPI_TILE_N, STAGE) @@ -820,11 +820,11 @@ class MixedInputGemmKernel: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_ptr_sync_barrier, allocator_warp_id=self.epilog_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp16.py b/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp16.py index 7eba992bd..58e4a7f85 100644 --- a/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp16.py +++ b/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp16.py @@ -568,7 +568,7 @@ class BlackwellMultiHeadLatentAttentionForwardFP16: cutlass.Int64, self.load_pt_stage * 2 ] # Tmem dealloc cluster barrier - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 # Tmem holding buffer tmem_holding_buf: cutlass.Int32 @@ -641,7 +641,6 @@ class BlackwellMultiHeadLatentAttentionForwardFP16: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, - smem=SplitKVKernelSharedStorage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @@ -657,7 +656,6 @@ class BlackwellMultiHeadLatentAttentionForwardFP16: ).launch( grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), block=[self.threads_per_warp * self.num_compute_warps, 1, 1], - smem=MAX_SPLITS * self.acc_dtype.width // 8, stream=stream, min_blocks_per_mp=1, ) @@ -838,11 +836,11 @@ class BlackwellMultiHeadLatentAttentionForwardFP16: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_ptr_sync_bar, allocator_warp_id=self.mma_warp_id, is_two_cta=self.use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) load_q_pipeline = self.make_and_init_load_qkv_pipeline( diff --git a/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp8.py b/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp8.py index e6383ef82..2ae63af0f 100644 --- a/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp8.py +++ b/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp8.py @@ -661,7 +661,7 @@ class BlackwellMultiHeadLatentAttentionForwardFP8: ] # Tmem dealloc cluster barrier - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 # Tmem holding buffer tmem_holding_buf: cutlass.Int32 @@ -707,7 +707,6 @@ class BlackwellMultiHeadLatentAttentionForwardFP8: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, - smem=SplitKVKernelSharedStorage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) @@ -723,7 +722,6 @@ class BlackwellMultiHeadLatentAttentionForwardFP8: ).launch( grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), block=[self.threads_per_warp * self.num_compute_warps, 1, 1], - smem=MAX_SPLITS * self.acc_dtype.width // 8, stream=stream, min_blocks_per_mp=1, ) @@ -904,11 +902,11 @@ class BlackwellMultiHeadLatentAttentionForwardFP8: # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=self.tmem_ptr_sync_bar, allocator_warp_id=self.mma_warp_id, is_two_cta=self.use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) load_q_pipeline = self.make_and_init_load_qkv_pipeline( diff --git a/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py b/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py index 39f148486..971165374 100644 --- a/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py +++ b/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py @@ -43,9 +43,9 @@ from cutlass.cute.runtime import from_dlpack from dataclasses import dataclass, field """ -This example provides an experimental implementation of the SM103 batched 3xFP4 blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. +This example provides an experimental implementation of the SM103 batched FP4 Ultra blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. -A high-performance persistent batched 3xFP4 blockscaled GEMM example for the NVIDIA Blackwell SM103 architecture +A high-performance persistent batched FP4 Ultra blockscaled GEMM example for the NVIDIA Blackwell SM103 architecture using CUTE DSL. - Matrix A is MxKxL, L is batch dimension, A can only be row-major("K") for MXF4/NVF4 input type - Matrix B is NxKxL, L is batch dimension, B can only be row-major("K") for MXF4/NVF4 input type @@ -166,7 +166,7 @@ class Sm103BlockScaledPersistentDenseGemmKernel: cluster_shape_mn: Tuple[int, int], use_tma_store: bool, ): - """Initializes the configuration for a Blackwell SM103 3xFP4 GEMM kernel. + """Initializes the configuration for a Blackwell SM103 FP4 Ultra GEMM kernel. This configuration includes several key aspects: @@ -603,7 +603,7 @@ class Sm103BlockScaledPersistentDenseGemmKernel: sf_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_sf_stage] acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (MMA, MMA_M, MMA_K, STAGE) sA: cute.struct.Align[ @@ -800,11 +800,11 @@ class Sm103BlockScaledPersistentDenseGemmKernel: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init @@ -1810,7 +1810,7 @@ class Sm103BlockScaledPersistentDenseGemmKernel: mma_tiler_mn: Tuple[int, int], a_source: tcgen05.OperandSource = tcgen05.OperandSource.SMEM, ) -> cute.TiledMma: - """Create a blockscaled trivial tiled MMA for SM103 (3xFP4), K fixed to 96. + """Create a blockscaled trivial tiled MMA for SM103 (FP4 Ultra), K fixed to 96. Returns a tcgen05 MMA configured for the given (M, N) tiler and CTA group. @@ -2653,7 +2653,7 @@ def run( :return: Execution time of the GEMM kernel :rtype: float """ - print(f"Running Sm103 Persistent 3xfp4 Dense BlockScaled GEMM test with:") + print(f"Running Sm103 Persistent FP4 Ultra Dense BlockScaled GEMM test with:") print(f"mnkl: {mnkl}") print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") print(f"C dtype: {c_dtype}") @@ -2954,7 +2954,7 @@ if __name__ == "__main__": ) parser = argparse.ArgumentParser( - description="Example of Sm103 3xfp4 Dense Persistent BlockScaled GEMM." + description="Example of Sm103 FP4 Ultra Dense Persistent BlockScaled GEMM." ) parser.add_argument( diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py index d953be1ed..24c406047 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py @@ -8,13 +8,10 @@ # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. - import argparse -from typing import Tuple, Type, Callable -from functools import partial, lru_cache +from typing import Tuple import cutlass -from cutlass import Numeric import cutlass.cute as cute import cutlass.utils as utils import cutlass.pipeline as pipeline @@ -69,7 +66,6 @@ def kernel( a_smem_layout: cute.ComposedLayout, b_smem_layout: cute.ComposedLayout, ): - # Current thread/warp/block coordinates tidx, _, _ = cute.arch.thread_idx() warp_idx = cute.arch.warp_idx() @@ -103,7 +99,7 @@ def kernel( num_threads=threads_per_cta, ) tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, ) num_tmem_cols = 512 @@ -143,15 +139,15 @@ def kernel( # (bM, bN) gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None)) thr_mma = tiled_mma.get_slice(0) - # (MMA, MMA_M, MMA_K, RestK) + # (MMA, MMA_M, MMA_K) tCgA = thr_mma.partition_A(gA) - # (MMA, MMA_N, MMA_K, RestK) + # (MMA, MMA_N, MMA_K) tCgB = thr_mma.partition_B(gB) # (MMA, MMA_M, MMA_N) tCgC = thr_mma.partition_C(gC) - # (MMA, MMA_M, MMA_K, STAGE) + # (MMA, MMA_M, MMA_K) tCrA = tiled_mma.make_fragment_A(sA) - # (MMA, MMA_N, MMA_K, STAGE) + # (MMA, MMA_N, MMA_K) tCrB = tiled_mma.make_fragment_B(sB) # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) @@ -199,14 +195,14 @@ def kernel( tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) # (TmemCpy,NumTmemCpy,NumTiles) - tCtC = tmem_thr_copy.partition_S(tCtAcc_epi) + tDtC = tmem_thr_copy.partition_S(tCtAcc_epi) # (TmemCpy,NumTmemCpy,NumTiles) - tCgC = tmem_thr_copy.partition_D(gC_epi) + tDgC = tmem_thr_copy.partition_D(gC_epi) # (TmemCpy,NumTmemCpy) - tCrAcc = cute.make_rmem_tensor(tCgC[None, None, 0].shape, acc_dtype) + tCrAcc = cute.make_rmem_tensor(tDgC[None, None, 0].shape, acc_dtype) # (TmemCpy,NumTmemCpy) - tCrC = cute.make_rmem_tensor(tCgC[None, None, 0].shape, io_dtype) + tCrC = cute.make_rmem_tensor(tDgC[None, None, 0].shape, io_dtype) # # 2. Main loop @@ -233,8 +229,6 @@ def kernel( # Execute one K-block worth of MMA instructions ab_full = ab_consumer.wait_and_advance() - - # tCtAcc += tCrA * tCrB num_k_blocks = cute.size(tCrA, mode=[2]) for k_block_idx in cutlass.range_constexpr(num_k_blocks): k_block_coord = (None, None, k_block_idx, ab_full.index) @@ -265,10 +259,10 @@ def kernel( # TMEM -> RMEM -> GEMM # Sub-tiling for better instruction-level parallelism - for i in cutlass.range(cute.size(tCtC, mode=[2])): - cute.copy(tmem_tiled_copy, tCtC[None, None, i], tCrAcc) + for i in cutlass.range(cute.size(tDtC, mode=[2])): + cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc) tCrC.store(tCrAcc.load().to(io_dtype)) - cute.autovec_copy(tCrC, tCgC[None, None, i]) + cute.autovec_copy(tCrC, tDgC[None, None, i]) acc_full.release() # Deallocate TMEM @@ -350,44 +344,10 @@ def host_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor): ) -@lru_cache(maxsize=1) -def prepare_run( - callable: Callable, - m: int, - n: int, - k: int, - a_dtype: Type[Numeric], - b_dtype: Type[Numeric], - c_dtype: Type[Numeric], -) -> tuple[Callable, tuple]: - import cutlass.torch as cutlass_torch - - a, b, c = cutlass_torch.prepare_tensors_for_gemm( - (m, n, k), a_dtype, b_dtype, c_dtype - ) - a_ = ( - from_dlpack(a, assumed_align=32) - .mark_layout_dynamic(leading_dim=1) - .mark_compact_shape_dynamic(mode=1, divisibility=k) - ) - b_ = ( - from_dlpack(b, assumed_align=32) - .mark_layout_dynamic(leading_dim=1) - .mark_compact_shape_dynamic(mode=1, divisibility=k) - ) - c_ = ( - from_dlpack(c, assumed_align=32) - .mark_layout_dynamic(leading_dim=1) - .mark_compact_shape_dynamic(mode=1, divisibility=n) - ) - compiled_fn = cute.compile(callable, a_, b_, c_, options="--generate-line-info") - return partial(compiled_fn, a_, b_, c_), (a, b, c) - - def run_dense_gemm( mnk: Tuple[int, int, int], tolerance: float, -) -> None: +): global torch, cutlass_torch import torch import cutlass.torch as cutlass_torch @@ -402,23 +362,48 @@ def run_dense_gemm( m, n, k = mnk torch.manual_seed(1111) - run_fn, (a, b, c) = prepare_run( - host_function, m, n, k, io_dtype, io_dtype, io_dtype + # Make K-major tensors (torch tensors are row-major) + def make_tensors(mn, k, dtype): + shape = (mn, k) + return ( + torch.empty(*shape, dtype=torch.int32) + .random_(-2, 2) + .to(dtype=dtype, device="cuda") + ) + + a = make_tensors(m, k, cutlass_torch.dtype(io_dtype)) + b = make_tensors(n, k, cutlass_torch.dtype(io_dtype)) + c = make_tensors(m, n, cutlass_torch.dtype(io_dtype)) + a_tensor = ( + from_dlpack(a, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=k) ) + b_tensor = ( + from_dlpack(b, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=k) + ) + c_tensor = ( + from_dlpack(c, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=n) + ) + # Entry point to the host JIT function - run_fn() + host_function(a_tensor, b_tensor, c_tensor, no_cache=True) # Compute reference result and verify - ref = torch.einsum("mk,nk->mn", a.to(torch.float32), b.to(torch.float32)) + ref = (torch.einsum("mk,nk->mn", a.to(torch.float32), b.to(torch.float32))).cpu() torch.testing.assert_close( - c, ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05 + c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05 ) if __name__ == "__main__": - def parse_comma_separated_ints(s: str) -> list[int]: + def parse_comma_separated_ints(s: str): try: return [int(x.strip()) for x in s.split(",")] except ValueError: @@ -443,14 +428,14 @@ if __name__ == "__main__": parser.add_argument( "--tolerance", type=float, default=1e-01, help="Tolerance for validation" ) - args = parser.parse_args() - if len(args.mnk) != 3: parser.error("--mnk must contain exactly 3 values") if args.mnk[0] % mma_tiler_mnk[0] != 0 or args.mnk[1] % mma_tiler_mnk[1] != 0: parser.error("m n must be divisible by mma_tiler_mn") - run_dense_gemm(args.mnk, args.tolerance) + run_dense_gemm( + args.mnk, + args.tolerance, + ) print("PASS") - diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py index f4abf4f04..63ce94c65 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py @@ -65,8 +65,7 @@ Constraints for this example: io_dtype = cutlass.Float16 acc_dtype = cutlass.Float32 -use_2cta_instrs = True -cluster_shape_mnk = (2, 1, 1) if use_2cta_instrs else (1, 1, 1) +cluster_shape_mnk = (2, 1, 1) mma_inst_shape_mnk = (256, 256, 16) mma_tiler_mnk = (256, 256, 64) threads_per_cta = 128 @@ -96,7 +95,6 @@ def kernel( b_smem_layout: cute.ComposedLayout, cta_layout_vmnk: cute.Layout, ): - # Current thread/warp/block coordinates tidx, _, _ = cute.arch.thread_idx() warp_idx = cute.arch.warp_idx() @@ -174,15 +172,15 @@ def kernel( # (bM, bN) gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None)) thr_mma = tiled_mma.get_slice(mma_coord_vmnk[0]) - # (MMA, MMA_M, MMA_K, RestK) + # (MMA, MMA_M, MMA_K) tCgA = thr_mma.partition_A(gA) - # (MMA, MMA_N, MMA_K, RestK) + # (MMA, MMA_N, MMA_K) tCgB = thr_mma.partition_B(gB) # (MMA, MMA_M, MMA_N) tCgC = thr_mma.partition_C(gC) - # (MMA, MMA_M, MMA_K, STAGE) + # (MMA, MMA_M, MMA_K) tCrA = tiled_mma.make_fragment_A(sA) - # (MMA, MMA_N, MMA_K, STAGE) + # (MMA, MMA_N, MMA_K) tCrB = tiled_mma.make_fragment_B(sB) # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) @@ -217,10 +215,10 @@ def kernel( num_threads=threads_per_cta, ) tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, is_two_cta=cute.size(cta_layout_vmnk, mode=[0]) > 1, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) num_tmem_cols = 512 tmem.allocate(num_tmem_cols) @@ -232,7 +230,7 @@ def kernel( # Swap the pointer in tCtAcc tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout) - subtile_cnt = 1 if mma_tiler_mnk[0] == 64 else 4 + subtile_cnt = 4 # (EpiTile) epi_tiler = ( (cute.size(tCtAcc, mode=[0, 0]), cute.size(tCtAcc, mode=[0, 1]) // subtile_cnt), @@ -244,24 +242,21 @@ def kernel( # Every thread loads 64 x fp32 tmem_atom = cute.make_copy_atom( - tcgen05.Ld16x256bOp(tcgen05.Repetition.x8) - if mma_tiler_mnk[0] == 64 - else tcgen05.Ld32x32bOp(tcgen05.Repetition.x64), + tcgen05.Ld32x32bOp(tcgen05.Repetition.x64), cutlass.Float32, ) - tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_atom, tCtAcc_epi[None, 0]) tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) # (TmemCpy,NumTmemCpy,NumTiles) - tCtC = tmem_thr_copy.partition_S(tCtAcc_epi) + tDtC = tmem_thr_copy.partition_S(tCtAcc_epi) # (TmemCpy,NumTmemCpy,NumTiles) - tCgC = tmem_thr_copy.partition_D(gC_epi) + tDgC = tmem_thr_copy.partition_D(gC_epi) # (TmemCpy,NumTmemCpy) - tCrAcc = cute.make_rmem_tensor(tCgC[None, None, 0].shape, acc_dtype) + tCrAcc = cute.make_rmem_tensor(tDgC[None, None, 0].shape, acc_dtype) # (TmemCpy,NumTmemCpy) - tCrC = cute.make_rmem_tensor(tCgC[None, None, 0].shape, io_dtype) + tCrC = cute.make_rmem_tensor(tDgC[None, None, 0].shape, io_dtype) # # 2. Main loop @@ -271,8 +266,8 @@ def kernel( if warp_idx == 0: # Wait for a empty accumulator buffer if is_leader_cta: - acc_producer.acquire() - for k_tile in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2): + acc_producer.acquire_and_advance() + for _ in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2): # Issue TMA loads ab_empty = ab_producer.acquire_and_advance() cute.copy( @@ -310,7 +305,6 @@ def kernel( # Signal that the accumulator is fully computed if is_leader_cta: acc_producer.commit() - acc_producer.advance() # # 3. Epilogue @@ -321,13 +315,12 @@ def kernel( # Wait for the accumulator buffer to be full acc_full = acc_consumer.wait_and_advance() - # TMEM -> RMEM -> GEMM # Sub-tiling for better instruction-level parallelism - for i in cutlass.range(cute.size(tCtC, mode=[2])): - cute.copy(tmem_tiled_copy, tCtC[None, None, i], tCrAcc) + for i in cutlass.range(cute.size(tDtC, mode=[2])): + cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc) tCrC.store(tCrAcc.load().to(io_dtype)) - cute.autovec_copy(tCrC, tCgC[None, None, i]) + cute.autovec_copy(tCrC, tDgC[None, None, i]) acc_full.release() # Ensure used buffers are properly synchronized before producer exit. @@ -353,7 +346,7 @@ def host_function( io_dtype, acc_dtype, mma_inst_shape_mnk, - tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE, + tcgen05.CtaGroup.TWO, tcgen05.OperandSource.SMEM, tcgen05.OperandMajorMode.K, tcgen05.OperandMajorMode.K, @@ -381,16 +374,14 @@ def host_function( cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (tiled_mma.thr_id,)) # Construct TMA load atoms - op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp( - tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE - ) + op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A( op, a, a_smem_layout_one_stage, mma_tiler_mnk, tiled_mma, - cta_layout_vmnk.shape, + cta_layout_vmnk.shape, # take the layout and extract the shape internally ) b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B( op, @@ -403,8 +394,7 @@ def host_function( grid_shape = cute.round_up( cute.ceil_div( - (*c.layout.shape, 1), - (mma_tiler_mnk[0] // (2 if use_2cta_instrs else 1), *mma_tiler_mnk[1:]), + (*c.layout.shape, 1), (mma_tiler_mnk[0] // 2, *mma_tiler_mnk[1:]) ), cluster_shape_mnk, ) diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py index b5ad6d3b1..e2d47397d 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py @@ -981,7 +981,7 @@ def run_dense_gemm( import cutlass.torch as cutlass_torch print("===================================================================") - print("Running Blackwell fp16 GEMM example 4 (with MIX CGA support):") + print("Running Blackwell fp16 GEMM example 4 (with MIX cluster size support):") print(f" mnk: {mnk}") print(f" tolerance: {tolerance}") print(f" Preferred cluster shape: {preferred_cluster_shape_mnk}") diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_0.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_0.py index f0dfeea15..d2631f1dd 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_0.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_0.py @@ -500,7 +500,7 @@ class Sm100BlockScaledDenseGemmKernel: num_threads=self.threads_per_cta, ) tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, ) tmem.allocate(self.num_tmem_alloc_cols) diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_1.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_1.py index 715de1b74..bebe522c4 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_1.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_1.py @@ -436,7 +436,7 @@ class Sm100BlockScaledDenseGemmKernel: class SharedStorage: ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 smem = utils.SmemAllocator() @@ -638,10 +638,10 @@ class Sm100BlockScaledDenseGemmKernel: num_threads=self.threads_per_cta, ) tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, is_two_cta=cute.size(cta_layout_vmnk, mode=[0]) > 1, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) tmem.allocate(self.num_tmem_alloc_cols) tmem.wait_for_alloc() diff --git a/examples/python/CuTeDSL/experimental/blackwell/dense_gemm_cute_pipeline.py b/examples/python/CuTeDSL/experimental/blackwell/dense_gemm_cute_pipeline.py index b4f7e5517..289662219 100755 --- a/examples/python/CuTeDSL/experimental/blackwell/dense_gemm_cute_pipeline.py +++ b/examples/python/CuTeDSL/experimental/blackwell/dense_gemm_cute_pipeline.py @@ -581,7 +581,7 @@ class PersistentDenseGemmKernel: acc_full_mbar_ptr: cute.struct.MemRange[ cutlass.Int64, self.num_acc_stage * 2 ] - tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 smem = utils.SmemAllocator() @@ -632,11 +632,11 @@ class PersistentDenseGemmKernel: ) # Tensor memory dealloc barrier init tmem = utils.TmemAllocator( - storage.tmem_holding_buf, + storage.tmem_holding_buf.ptr, barrier_for_retrieve=tmem_alloc_barrier, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta_instrs, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, ) # Cluster arrive after barrier init diff --git a/examples/python/CuTeDSL/helpers/fmha_helpers.py b/examples/python/CuTeDSL/helpers/fmha_helpers.py index a465f577e..270438f41 100644 --- a/examples/python/CuTeDSL/helpers/fmha_helpers.py +++ b/examples/python/CuTeDSL/helpers/fmha_helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/examples/python/CuTeDSL/hopper/fmha.py b/examples/python/CuTeDSL/hopper/fmha.py index aac359fe9..1892e848c 100644 --- a/examples/python/CuTeDSL/hopper/fmha.py +++ b/examples/python/CuTeDSL/hopper/fmha.py @@ -527,7 +527,6 @@ class HopperFusedMultiHeadAttentionForward: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, - smem=self.shared_storage.size_in_bytes(), stream=stream, min_blocks_per_mp=1, ) diff --git a/examples/python/CuTeDSL/hopper/grouped_gemm.py b/examples/python/CuTeDSL/hopper/grouped_gemm.py index 9430eba56..a0631be76 100644 --- a/examples/python/CuTeDSL/hopper/grouped_gemm.py +++ b/examples/python/CuTeDSL/hopper/grouped_gemm.py @@ -107,7 +107,7 @@ Constraints (same as dense_gemm_persistent.py plus): * Cluster shape M/N: power of 2, total <= 4 * Contiguous dim must be 16-byte aligned -Debug environment knobs: +Debug environment options: * `GROUPED_GEMM_FORCE_CUTE_COPY=1` Disable the non-mcast NVVM TMA load path and always use `cute.copy`. """ diff --git a/examples/python/CuTeDSL/jax/cutlass_call_basic.py b/examples/python/CuTeDSL/jax/cutlass_call_basic.py index e1860403d..e0232e27e 100644 --- a/examples/python/CuTeDSL/jax/cutlass_call_basic.py +++ b/examples/python/CuTeDSL/jax/cutlass_call_basic.py @@ -30,7 +30,6 @@ from functools import partial import jax import jax.numpy as jnp -import cutlass import cutlass.cute as cute import cutlass.jax as cjax import cuda.bindings.driver as cuda @@ -140,12 +139,12 @@ if __name__ == "__main__": def run_cutlass_kernel(a, b, x, y): call = cjax.cutlass_call( launch_jax_wrapper, - # Jax requires output shapes/dtype information for each output + # Describe the shape and dtype of each output buffer. output_shape_dtype=( jax.ShapeDtypeStruct(a.shape, a.dtype), jax.ShapeDtypeStruct(b.shape, a.dtype), ), - # Static jit arguments are passed via additional keyword arguments + # Static jit arguments are passed via additional keyword arguments. x=x, y=y, ) @@ -165,12 +164,11 @@ if __name__ == "__main__": # to the kernel. Alternatively you can wrap using another separate cute.jit # function. lambda stream, a, b, c, d, *, x, y: launch(a, b, x, y, c, d, stream), - # Jax requires output shapes/dtype information for each output output_shape_dtype=( jax.ShapeDtypeStruct(a.shape, a.dtype), jax.ShapeDtypeStruct(b.shape, a.dtype), ), - # Static jit arguments are passed via additional keyword arguments + # Static jit arguments are passed via additional keyword arguments. x=x, y=y, ) @@ -191,11 +189,12 @@ if __name__ == "__main__": jax.ShapeDtypeStruct(a.shape, a.dtype), jax.ShapeDtypeStruct(b.shape, a.dtype), ), - # By default cutlass_call will treat all tensors as dynamic shape. + # By default cutlass_call treats all tensors as dynamic shape. # Dynamic shapes are often expected for kernels so this default ensures # the broadest support. If you know that a kernel can accept fully static - # tensors then you can enable this flag to pass all tensors shapes and - # layouts known at compile time. + # tensors then you can enable this flag to compile all tensor shapes and + # layouts as constexpr values known at compile time. + # Individual tensors may opt out via .mark_layout_dynamic(). use_static_tensors=True, x=x, y=y, @@ -209,19 +208,15 @@ if __name__ == "__main__": @partial(jax.jit, static_argnums=[2, 3]) def run_cutlass_kernel_with_modes(a, b, x, y): + # input_spec and output_spec accept TensorSpec values to attach layout + # metadata to tensors. mode remaps the logical dimension order seen by + # the kernel. static=True compiles that tensor's layout as constexpr. call = cjax.cutlass_call( lambda stream, a, b, c, d, *, x, y: launch(a, b, x, y, c, d, stream), output_shape_dtype=( jax.ShapeDtypeStruct(a.shape, a.dtype), jax.ShapeDtypeStruct(b.shape, a.dtype), ), - # The modes of the layout for each tensor can be specified using the - # TensorSpec. By default modes will align with the physical layout - # but can be mapped to specific index position. If None is passed - # then the default mode is assumed for that tensor. - # - # Individual static/dynamic settings may also be applied. For example - # a specific tensor can be marked to have static shape. input_spec=( cjax.TensorSpec(mode=(1, 0, 2), static=True), cjax.TensorSpec(mode=(3, 1, 2, 0)), @@ -245,9 +240,8 @@ if __name__ == "__main__": jax.ShapeDtypeStruct(a.shape, a.dtype), jax.ShapeDtypeStruct(b.shape, b.dtype), ), - # Can specify the input tensors that are aliasing outputs of this call. - # To avoid allocating separate output buffers. This is useful for kernels - # that update a tensor. + # Map input indices to output indices so XLA can reuse the input + # buffers for the outputs, avoiding extra allocations. input_output_aliases={0: 0, 1: 1}, x=x, y=y, diff --git a/examples/python/CuTeDSL/jax/cutlass_call_export.py b/examples/python/CuTeDSL/jax/cutlass_call_export.py index 39a8baf1b..f0a494f2e 100644 --- a/examples/python/CuTeDSL/jax/cutlass_call_export.py +++ b/examples/python/CuTeDSL/jax/cutlass_call_export.py @@ -26,45 +26,45 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import pytest -from functools import partial -import argparse - -import cuda.bindings.driver as cuda - -import cutlass -import cutlass.cute as cute - -import jax -import jax.numpy as jnp -from jax import export - -from cutlass.jax import cutlass_call, get_export_disabled_safety_checks -from cutlass.jax.testing import create_tensor - """ Examples of using jax.export APIs with functions using cutlass_call. -This example demonstrates the use of jax.export with CuTe DSL kernel. It assumes -familiarity with CuTe DSL concepts such as layouts and dynamic shapes as well as -Jax's exporting and serialization features: +This example demonstrates three export modes: + +1. Concrete shapes -- shapes are fixed constants baked into the export. +2. Unconstrained symbolic shapes ("a, b") +3. Constrained symbolic shapes ("32*M, 16*N") + +The JAX function being exported is the same in all three cases; only the +shape specification passed to jax.export differs. + +It assumes familiarity with CuTe DSL concepts such as layouts and dynamic shapes +as well as JAX's exporting and serialization features: https://docs.jax.dev/en/latest/export/index.html#export To run this example: .. code-block:: bash - # Run with defaults - python examples/jax/cutlass_call_export.py + python examples/jax/cutlass_call_export.py --M 512 --N 256 - # Run with shape (1024, 512) - python examples/jax/cutlass_call_export.py --M 1024 --N 512 - - # Export with symbolic shapes. - python examples/jax/cutlass_call_export.py --export_symbolic """ +import argparse +import cuda.bindings.driver as cuda + +import cutlass.cute as cute + +import jax +import jax.numpy as jnp +from jax import export + +from cutlass.jax import cutlass_call, get_export_disabled_safety_checks, TensorSpec +from cutlass.jax.testing import create_tensor + + +# Simple element-wise addition kernel: gC[i,j] = gA[i,j] + gB[i,j] @cute.kernel def kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): tidx, _, _ = cute.arch.thread_idx() @@ -84,9 +84,6 @@ def kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): @cute.jit def launch(stream: cuda.CUstream, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor): - print("mA: ", mA.layout) - print("mB: ", mB.layout) - print("mC: ", mC.layout) num_threads_per_block = 256 m, n = mA.shape kernel(mA, mB, mC).launch( @@ -96,63 +93,100 @@ def launch(stream: cuda.CUstream, mA: cute.Tensor, mB: cute.Tensor, mC: cute.Ten ) -def run_example(M, N, export_symbolic_shapes): +def _export_and_run(f, ref_f, input_shape_dtype, run_shapes): + """Export f, serialize/deserialize, then run on each shape in run_shapes. + + Both inputs (a, b) are assumed to share the same input_shape_dtype. + """ + print(f"Exporting with input signature: ({input_shape_dtype}, {input_shape_dtype})") + + # jax.export can be used to export a jit function containing cutlass_call. + # CUTLASS custom call targets are not on JAX's built-in stable custom-call + # allowlist, so we pass them via disabled_checks to suppress that safety check. + exported = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks()) + traced = exported(input_shape_dtype, input_shape_dtype) + + blob = traced.serialize() + print(f"Serialized computation is {len(blob)} bytes.") + + rehydrated = export.deserialize(blob) + + key = jax.random.key(1123) + a_key, b_key = jax.random.split(key, 2) + for shape in run_shapes: + a = create_tensor(shape, dtype=jnp.float32, key=a_key) + b = create_tensor(shape, dtype=jnp.float32, key=b_key) + c = rehydrated.call(a, b) + assert jnp.allclose(c, ref_f(a, b)), f"Mismatch at shape {shape}" + print(f" shape {shape}: OK") + + +def run_example(M, N): + @jax.jit + def ref_f(a, b): + return jax.nn.sigmoid(a + b) + + # The same JAX function is used in all three examples below. The export + # mode is determined entirely by the shape spec passed to jax.export. @jax.jit def f(a, b): call = cutlass_call(launch, output_shape_dtype=a) return jax.nn.sigmoid(call(a, b)) + # ── 1. Concrete shapes ──────────────────────────────────────────────────── + # Shapes are fixed constants baked into the export. The deserialized + # computation only accepts exactly these dimensions at runtime. + print("\nConcrete shapes:") + + input_shape_dtype = jax.ShapeDtypeStruct((M, N), jnp.float32) + _export_and_run( + f, + ref_f, + input_shape_dtype, + run_shapes=[(M, N)], # concrete exports reject any other shape + ) + + # ── 2. Unconstrained symbolic shapes ───────────────────────────────────── + # Both dimensions are fully dynamic. The exported computation accepts any + # (M, N) at runtime without recompilation. + print("\nUnconstrained symbolic shapes:") + + a_sym, b_sym = export.symbolic_shape("a, b") + input_shape_dtype = jax.ShapeDtypeStruct((a_sym, b_sym), jnp.float32) + _export_and_run( + f, + ref_f, + input_shape_dtype, + run_shapes=[(M, N), (M * 2, N * 4), (M * 4, N * 4)], + ) + + # ── 3. Constrained symbolic shapes (divisibility) ───────────────────────── + # Shapes are declared as multiples of a tile size via TensorSpec.divisibility. + # The symbolic expression "32*M, 16*N" tells jax.export that dim 0 is always + # a multiple of 32 and dim 1 is always a multiple of 16. This lets the + # compiler generate more efficient code (e.g. no remainder handling). + # Runtime shapes must satisfy these divisibility constraints. + print("\nConstrained symbolic shapes:") + @jax.jit - def ref_f(a, b): - return jax.nn.sigmoid(a + b) + def f_divisible(a, b): + spec = TensorSpec(divisibility=(32, 16)) + call = cutlass_call( + launch, + output_shape_dtype=a, + input_spec=(spec, spec), + output_spec=spec, + ) + return jax.nn.sigmoid(call(a, b)) - # Symbolic or partially shapes are supported by cutlass_call and cute.Tensor - # This allows export of functions calling Cut eDSL kernels w/o having to re-compile - # the kernel for each new shape. - if export_symbolic_shapes: - a, b = export.symbolic_shape("a, b") - export_shape_dtype = jax.ShapeDtypeStruct((a, b), jnp.float32) - else: - export_shape_dtype = jax.ShapeDtypeStruct((M, N), jnp.float32) - - print("Exporting with input signature: ") - print(f"({export_shape_dtype}, {export_shape_dtype})") - - # jax.export can be used to export a jit function containing cutlass_call. - # The function get_export_disabled_safety_checks() returns a list of custom - # call targets that are used by cutlass_call not part of Jax's built-in - # list of stable custom calls. - exported = jax.export.export(f, disabled_checks=get_export_disabled_safety_checks()) - traced = exported(export_shape_dtype, export_shape_dtype) - - # Serialize the computation to a byte blob. - blob = traced.serialize() - print(f"Serialized computation is {len(blob)} bytes.") - - # Deserialize and run - rehydrated = export.deserialize(blob) - - key = jax.random.key(1123) - a_key, b_key = jax.random.split(key, 2) - - a = create_tensor((M, N), dtype=jnp.float32, key=a_key) - b = create_tensor((M, N), dtype=jnp.float32, key=b_key) - c = rehydrated.call(a, b) - assert jnp.allclose(c, ref_f(a, b)) - - # If the computation was exported with dynamic shapes then we can also - # call it with different shapes. The kernel will not be re-compiled - # even though the shapes are changing. - if export_symbolic_shapes: - a = create_tensor((M * 2, N * 4), dtype=jnp.float32, key=a_key) - b = create_tensor((M * 2, N * 4), dtype=jnp.float32, key=b_key) - c = rehydrated.call(a, b) - assert jnp.allclose(c, ref_f(a, b)) - - a = create_tensor((M * 4, N * 4), dtype=jnp.float32, key=a_key) - b = create_tensor((M * 4, N * 4), dtype=jnp.float32, key=b_key) - c = rehydrated.call(a, b) - assert jnp.allclose(c, ref_f(a, b)) + m_sym, n_sym = export.symbolic_shape("32*M, 16*N") + input_shape_dtype = jax.ShapeDtypeStruct((m_sym, n_sym), jnp.float32) + _export_and_run( + f_divisible, + ref_f, + input_shape_dtype, + run_shapes=[(M, N), (M * 2, N * 2), (M * 4, N * 4)], + ) if __name__ == "__main__": @@ -161,8 +195,7 @@ if __name__ == "__main__": ) parser.add_argument("--M", default=512, type=int) parser.add_argument("--N", default=256, type=int) - parser.add_argument("--export_symbolic", action="store_true") args = parser.parse_args() - run_example(args.M, args.N, args.export_symbolic) + run_example(args.M, args.N) print("PASS") diff --git a/examples/python/CuTeDSL/jax/cutlass_call_sharding.py b/examples/python/CuTeDSL/jax/cutlass_call_sharding.py index e40687ff7..5f92cc55a 100644 --- a/examples/python/CuTeDSL/jax/cutlass_call_sharding.py +++ b/examples/python/CuTeDSL/jax/cutlass_call_sharding.py @@ -27,14 +27,12 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from functools import partial -import argparse import jax import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P, AxisType from jax.experimental.custom_partitioning import custom_partitioning -import cutlass import cutlass.cute as cute import cutlass.jax as cjax from cutlass.jax.testing import create_tensor diff --git a/examples/python/CuTeDSL/jax/elementwise_apply_example.py b/examples/python/CuTeDSL/jax/elementwise_apply_example.py index c958d09ce..d457d5a3a 100644 --- a/examples/python/CuTeDSL/jax/elementwise_apply_example.py +++ b/examples/python/CuTeDSL/jax/elementwise_apply_example.py @@ -30,7 +30,7 @@ import argparse import operator from functools import partial -from typing import List, Type +from typing import List import cuda.bindings.driver as cuda import cutlass diff --git a/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb b/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb index 3e8f2a979..1a98e3dc7 100644 --- a/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb +++ b/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb @@ -78,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -114,7 +114,7 @@ " ]\n", "\n", " synced_producer_consumer(SharedStorage, res).launch(\n", - " grid=(1, 1, 1), block=(64, 1, 1), smem=SharedStorage.size_in_bytes()\n", + " grid=(1, 1, 1), block=(64, 1, 1)\n", " )\n", "\n", "\n", @@ -455,7 +455,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -527,7 +527,7 @@ " ]\n", "\n", " async_pipeline_staged_kernel(SharedStorage, res, staging).launch(\n", - " grid=(1, 1, 1), block=(64, 1, 1), smem=SharedStorage.size_in_bytes()\n", + " grid=(1, 1, 1), block=(64, 1, 1)\n", " )\n", "\n", "\n", diff --git a/include/cute/arch/copy_sm100_tma.hpp b/include/cute/arch/copy_sm100_tma.hpp index 178995d27..82d66827a 100644 --- a/include/cute/arch/copy_sm100_tma.hpp +++ b/include/cute/arch/copy_sm100_tma.hpp @@ -663,4 +663,105 @@ struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST //////////////////////////////////////////////////////////////////////////////////////////////////// +struct SM100_TMA_LOAD_2D_GATHER4 +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1_i0, int32_t const& crd1_i1, int32_t const& crd1_i2, int32_t const& crd1_i3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile::gather4.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1_i0), "r"(crd1_i1), "r"(crd1_i2), "r"(crd1_i3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1_i0, int32_t const& crd1_i1, int32_t const& crd1_i2, int32_t const& crd1_i3) + { + #if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.2d.L2.global.tile::gather4" + " [%0, {%1, %2, %3, %4, %5}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1_i0), "r"(crd1_i1), "r"(crd1_i2), "r"(crd1_i3) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); + #endif + } + }; +}; + + +struct SM100_TMA_LOAD_MULTICAST_2D_GATHER4 +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1_i0, int32_t const& crd1_i1, int32_t const& crd1_i2, int32_t const& crd1_i3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile::gather4.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1_i0), "r"(crd1_i1), "r"(crd1_i2), "r"(crd1_i3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } + + using PREFETCH = SM100_TMA_LOAD_2D_GATHER4::PREFETCH; +}; + +struct SM100_TMA_STORE_2D_SCATTER4 +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1_i0, int32_t const& crd1_i1, int32_t const& crd1_i2, int32_t const& crd1_i3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.global.shared::cta.tile::scatter4.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1_i0), "r"(crd1_i1), "r"(crd1_i2), "r"(crd1_i3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm100_tma.hpp b/include/cute/atom/copy_traits_sm100_tma.hpp index f62971f5e..07639ba2a 100644 --- a/include/cute/atom/copy_traits_sm100_tma.hpp +++ b/include/cute/atom/copy_traits_sm100_tma.hpp @@ -239,6 +239,331 @@ struct Copy_Traits } }; +////////////////////////////////////////////////////////////////////////////// +////////////////////////// TMA_LOAD_GATHER /////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// Utility for unpacking TMA_LOAD arguments into a CopyOp +template +struct TMA_LOAD_GATHER_Unpack +{ + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "SM100_TMA_LOAD_2D_GATHER4 requires the destination be shared memory."); + + auto [src_crd, src_idx] = unzip_tensor(src); + + auto src_coord = src_crd.data().coord_; + static_assert(rank(src_coord) == 2, "SM100_TMA_LOAD_2D_GATHER4 requires 2D tensors"); + + Tensor idx = filter(src_idx); + static_assert(size(idx) == 4, "SM100_TMA_LOAD_2D_GATHER4 requires 4 indices"); + + auto coord = make_tuple(get<0>(src_coord), idx(0), idx(1), idx(2), idx(3)); + void* dst_ptr = cute::raw_pointer_cast(dst.data()); +#if 0 + auto [c0,c1,c2,c3,c4] = coord; + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + coord, make_seq<5>{}); + } +}; + +struct SM100_TMA_LOAD_2D_GATHER4_OP : SM100_TMA_LOAD_2D_GATHER4 {}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_LOAD_2D_GATHER4 before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +template +struct Copy_Traits + : TMA_LOAD_GATHER_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) + : opargs_(desc, mbar, cache) {} +}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + tuple const opargs_; + + // Construct with any other Traits' TMA Desc + template + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) + : opargs_({&traits.tma_desc_}) {} + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + auto [src_crd, src_idx] = unzip_tensor(src); + + auto src_coord = src_crd.data().coord_; + static_assert(rank(src_coord) == 2, "SM100_TMA_LOAD_2D_GATHER4 requires 2D tensors"); + + Tensor idx = filter(src_idx); + static_assert(size(idx) == 4, "SM100_TMA_LOAD_2D_GATHER4 requires 4 indices"); + + auto coord = make_tuple(get<0>(src_coord), idx(0), idx(1), idx(2), idx(3)); + + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + coord, make_seq<5>{}); + } +}; + +struct SM100_TMA_LOAD_MULTICAST_2D_GATHER4_OP : SM100_TMA_LOAD_MULTICAST_2D_GATHER4 {}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {&tma_desc_, &tma_mbar, multicast_mask, static_cast(cache_hint)}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {new_tma_desc, &tma_mbar, multicast_mask, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_LOAD_2D_GATHER4 before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +template +struct Copy_Traits + : TMA_LOAD_GATHER_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t cache) + : opargs_(desc, mbar, mask, cache) {} +}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + + auto [dsc_crd, dsc_idx] = unzip_tensor(dst); + + auto dsc_coord = dsc_crd.data().coord_; + static_assert(rank(dsc_coord) == 2, "SM100_TMA_STORE_2D_SCATTER4 requires 2D tensors"); + + Tensor idx = filter(dsc_idx); + static_assert(size(idx) == 4, "SM100_TMA_STORE_2D_SCATTER4 requires 4 indices"); + + auto coord = make_tuple(get<0>(dsc_coord), idx(0), idx(1), idx(2), idx(3)); + void* src_ptr = cute::raw_pointer_cast(src.data()); +#if 0 + auto [c0,c1,c2,c3,c4] = coord; + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(&traits.tma_desc_), seq<0>{}, + make_tuple(src_ptr), seq<0>{}, + coord, make_seq<5>{}); + } +}; //////////////////////////////////// // Make TMA /////////////////////////////////// @@ -428,12 +753,14 @@ make_tma_atom_A_sm100(CopyOp const& copy_op, // The size of the multicasting auto num_multicast = [&](){ if constexpr (is_same_v || - is_same_v) { + is_same_v || + is_same_v) { return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast } else if constexpr (is_same_v || is_same_v || - is_same_v) { + is_same_v || + is_same_v) { return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast } else { static_assert(dependent_false, "Unsupported TMA"); @@ -479,12 +806,14 @@ make_tma_atom_B_sm100(CopyOp const& copy_op, // The size of the multicasting auto num_multicast = [&](){ if constexpr (is_same_v || - is_same_v) { + is_same_v || + is_same_v) { return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast } else if constexpr (is_same_v || is_same_v || - is_same_v) { + is_same_v || + is_same_v) { return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast } else { static_assert(dependent_false, "Unsupported TMA"); diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 78ac598fb..93c297ea6 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -42,6 +42,7 @@ #include +#include #include namespace cute @@ -1149,13 +1150,27 @@ make_tma_copy_atom(CopyOp, auto smem_layout = get_nonswizzle_portion(slayout); auto tma_gbasis = detail::construct_tma_gbasis(gtensor, smem_layout, cta_v_map); - + auto tma_gbasis_tuple = conditional_return + ||is_same_v + ||is_same_v>( + [](auto tma_gbasis) constexpr { + static_assert(rank_v == 2, "TMA Gather/Scatter only supports 2D tensors"); + auto tma_gbasis_g4 = tma_gbasis.compose(make_identity_layout(make_shape(shape<0>(tma_gbasis), _1{}))); + auto tma_gbasis_g4_size = size(tma_gbasis_g4) * _4{}; + return make_tuple(tma_gbasis_g4, tma_gbasis_g4_size); + }, + [](auto tma_gbasis) constexpr { + auto tma_gbasis_size = size(tma_gbasis); + return make_tuple(tma_gbasis, tma_gbasis_size); + })(tma_gbasis); + auto _tma_gbasis = get<0>(tma_gbasis_tuple); + auto _tma_gbasis_size = get<1>(tma_gbasis_tuple); // // Construct the TMA Desc and the strides of the TMA Tensor // auto [tma_desc, aux_params] = detail::make_tma_copy_desc(gtensor, - tma_gbasis, + _tma_gbasis, smem_swizzle, num_multicast); @@ -1163,7 +1178,7 @@ make_tma_copy_atom(CopyOp, // Construct the Copy_Traits // - constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits_v; + constexpr int num_bits_per_tma = _tma_gbasis_size * sizeof_bits_v; using Traits = Copy_Traits, decltype(aux_params)>; using Atom = Copy_Atom; @@ -1397,17 +1412,16 @@ template + class... GTensors, + __CUTE_REQUIRES(conjunction_v...>)> CUTE_DEVICE auto tma_partition(Copy_Atom const& copy_atom, CtaCoord const& cta_coord, Layout const& cta_layout, // T: CTA coord -> logical multicast id Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) - Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) + GTensors const&... gtensors) // GMEM Tensor (TMATile, Rest...) { - CUTE_STATIC_ASSERT_V(size<0>(stensor) == size<0>(gtensor)); - // Invert the smem to get the largest contiguous vector in the smem layout Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor))); // Scale that up to cover all of the smem_coords @@ -1417,22 +1431,24 @@ tma_partition(Copy_Atom const& copy_atom, Layout tma_layout_v = make_layout(Int::NumValSrc>{}); auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v)); - // Append with _ until we cover all Rest... modes - auto glayout_V = append(layout_V, _); - auto slayout_V = append(layout_V, _); - // Transform tile mode and coalesce - Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) - Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) - // Offset inside the TMA-mode for the multicast auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout)); auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{})); - auto gcoord = append(multicast_coord, Int<0>{}); - auto scoord = append(multicast_coord, Int<0>{}); - Tensor gresult = domain_offset(gcoord, gtensor_v); - Tensor sresult = domain_offset(scoord, stensor_v); + // Existing convention is to return stensor last + return cute::transform(make_tuple(gtensors..., stensor), [&](auto && tensor) { + auto R = rank(tensor); + CUTE_STATIC_ASSERT_V(size<0>(stensor) == size<0>(tensor)); - return cute::make_tuple(gresult, sresult); + // Append with _ until we cover all Rest... modes + auto tlayout_V = append(layout_V, _); + + // Transform tile mode and coalesce + Tensor tensor_v = coalesce(tensor.compose(tlayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) + + // Offset inside the TMA-mode for the multicast + auto coord = append(multicast_coord, Int<0>{}); + return domain_offset(coord, tensor_v); + }); } // Explicit defaults for cta_coord and cta_layout diff --git a/include/cute/tensor_zip.hpp b/include/cute/tensor_zip.hpp index b23a7c619..b99b641d4 100644 --- a/include/cute/tensor_zip.hpp +++ b/include/cute/tensor_zip.hpp @@ -72,9 +72,12 @@ struct ZipIterator template CUTE_HOST_DEVICE constexpr - ZipIterator operator+(cute::tuple const& idxs) const { + auto operator+(cute::tuple const& idxs) const { static_assert(sizeof...(Index) == sizeof...(Iters), "Expect same number of offsets as iterators."); - return cute::transform(iters_, idxs, [](auto&& iter, auto&& idx) { return iter + idx; }); + return cute::transform_apply(iters_, idxs, + [](auto&& iter, auto&& idx) { return iter + idx; }, + [](auto... iter) { return ZipIterator(iter...); } + ); } template @@ -149,6 +152,13 @@ struct ZipLayout template struct is_layout> : true_type {}; +template +struct is_zip_layout : false_type {}; + +template +struct is_zip_layout> : true_type {}; + + // // make_zip_tensor and unzip_tensor // @@ -191,6 +201,23 @@ size(ZipLayout const& layouts) return size(get<0>(layouts.layouts_)); } + +template +CUTE_HOST_DEVICE constexpr +auto +get(ZipLayout const& layouts) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return get(t); })); +} + +template +CUTE_HOST_DEVICE constexpr +auto +layout(ZipLayout const& layouts) +{ + return get(layouts); +} + // // Manipulation // @@ -243,4 +270,45 @@ slice_and_offset(Coord const& c, ZipLayout const& layouts) return cute::make_tuple(ZipLayout(get<0>(result)), get<1>(result)); } +template +CUTE_HOST_DEVICE constexpr +auto +group(ZipLayout const& layouts) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return group(t); })); +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape(ZipLayout const& layouts) { + return shape(get<0>(layouts.layouts_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coshape(ZipLayout const& layouts) { + return cute::transform(layouts.layouts_, [&](auto t){ return coshape(t); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ZipLayout const& layouts) +{ + return size(coshape(layouts)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(ZipLayout const& layouts) { + return cute::fold(layouts.layouts_, make_layout(size(layouts)), + [](auto null, auto layout) { + return composition(null, nullspace(composition(layout, null))); + }); +} + + } // end namespace cute diff --git a/include/cutlass/conv/collective/builders/sm100_common.inl b/include/cutlass/conv/collective/builders/sm100_common.inl index 4256898b6..c896f2949 100644 --- a/include/cutlass/conv/collective/builders/sm100_common.inl +++ b/include/cutlass/conv/collective/builders/sm100_common.inl @@ -186,6 +186,18 @@ sm100_make_tiled_mma() { } } +template +constexpr int sm100_reduced_smem_capacity_bytes() { + if constexpr (cute::is_same_v) { + return cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + } + else if constexpr (cute::is_same_v) { + return cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid ArchTag, only Sm10x are supported."); + } +} ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::conv::collective::detail diff --git a/include/cutlass/conv/collective/builders/sm100_umma_builder.inl b/include/cutlass/conv/collective/builders/sm100_umma_builder.inl index 475f1be40..844f838ea 100644 --- a/include/cutlass/conv/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/conv/collective/builders/sm100_umma_builder.inl @@ -44,6 +44,7 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, conv::Operator ConvOp, class ElementA, class GmemLayoutA, @@ -58,7 +59,7 @@ template < class KernelScheduleType > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, ConvOp, ElementA, @@ -73,6 +74,9 @@ struct CollectiveBuilder< StageCountType, KernelScheduleType, cute::enable_if_t< + (cute::is_same_v + || cute::is_same_v + ) && (cute::is_same_v || cute::is_same_v || cute::is_same_v || @@ -191,12 +195,12 @@ private: CLCResponseStorage + TmemBasePtrsStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; static constexpr int PipelineStages = detail::compute_stage_count_or_override< - Sm100ReducedSmemCapacityBytes, ElementAMma, ElementBMma, SmemTileShape>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma, ElementBMma, SmemTileShape>(StageCountType{}); constexpr static int NumSpatialDimensions = detail::gmem_layout_tags_to_spatial_dims(); @@ -206,7 +210,8 @@ private: NumSpatialDimensions, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>; + ClusterShape_MNK, + ArchTag>; public: using CollectiveOp = cutlass::conv::collective::CollectiveConv< diff --git a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp index 49dad6729..75dd51d76 100644 --- a/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp +++ b/include/cutlass/conv/collective/sm100_implicit_gemm_umma_warpspecialized.hpp @@ -65,6 +65,7 @@ template < int NumSpatialDims, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShapeMNKL_, // (MmaAtomShapeM, MmaAtomShapeN, TileK, optional: TileL) class ElementA_, @@ -79,7 +80,8 @@ struct CollectiveConv< NumSpatialDims, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShapeMNKL_, ElementA_, ElementB_, @@ -96,7 +98,8 @@ struct CollectiveConv< NumSpatialDims, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = decltype(cute::take<0,3>(TileShapeMNKL_{})); // (MmaAtomShapeM, MmaAtomShapeN, TileK) using ElementA = ElementA_; using ElementB = ElementB_; diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 90016d028..80a1bcbe7 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -351,7 +351,8 @@ public: } else { if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 100 || - ConvKernel::ArchTag::kMinComputeCapability == 101) { + ConvKernel::ArchTag::kMinComputeCapability == 101 + ) { launch_result = ClusterLauncher::launch_with_fallback_cluster( grid, cluster, diff --git a/include/cutlass/conv/dispatch_policy.hpp b/include/cutlass/conv/dispatch_policy.hpp index 4eb1bafc3..821d104d3 100644 --- a/include/cutlass/conv/dispatch_policy.hpp +++ b/include/cutlass/conv/dispatch_policy.hpp @@ -116,14 +116,15 @@ template< int NumSpatialDimensions_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>> + class ClusterShape_ = cute::Shape,cute::C<1>,cute::C<1>>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedImplicitGemm { static constexpr int Stages = Stages_; static constexpr int NumSpatialDimensions = NumSpatialDimensions_; static constexpr Operator ConvOp = ConvOp_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelScheduleImplicitTmaWarpSpecializedSm100; static_assert(NumSpatialDimensions >= 1); diff --git a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp index 494ffe7af..aae691260 100644 --- a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp @@ -188,7 +188,6 @@ public: }; static constexpr int SharedStorageSize = sizeof(SharedStorage); - static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // Host facing host arguments struct Arguments { @@ -413,7 +412,7 @@ public: // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 3d0722094..f3eada3bb 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -310,6 +310,9 @@ constexpr bool is_tma_copy_engine() { || cute::is_base_of_v || cute::is_base_of_v || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v ) { return true; } diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_tma_warpspecialized.hpp index 8c0626a4a..3d4202c82 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_planar_complex_tma_warpspecialized.hpp @@ -138,10 +138,12 @@ private: constexpr static bool is_m_major_C = detail::is_m_major(); constexpr static bool is_m_major_D = detail::is_m_major(); + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. // This should be larger than the total number of TMA requests inflight (from update to issued to returned). // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). - constexpr static uint32_t NumTmaDescriptorsPerSm = NumMaxSchedulerPipelineStageCount + std::max(StagesC, (ReuseSmemC ? StagesC : StagesD)) + 2; + constexpr static uint32_t NumTmaDescriptorsPerSm = IsGroupedGemmKernel ? (NumMaxSchedulerPipelineStageCount + std::max(StagesC, (ReuseSmemC ? StagesC : StagesD)) + 2) : 1; using SmemLayoutC = decltype(tile_to_shape( SmemLayoutAtomC{}, diff --git a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp index 9ad99b13b..1903443d2 100644 --- a/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm100_epilogue_array_tma_warpspecialized.hpp @@ -181,11 +181,6 @@ private: // TMA store delay only benefits with loop unrolling constexpr static bool DelayTmaStore = DelayTmaStore_ and UnrollEpiLoop; - // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. - // This should be larger than the total number of TMA requests inflight (from update to issued to returned). - // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). - constexpr static uint32_t NumTmaDescriptorsPerSm = NumMaxSchedulerPipelineStageCount + std::max(StagesC, (ReuseSmemC ? StagesC : StagesD)) + 2; - struct CollectiveStorageWithC { alignas(SmemAlignmentC) ArrayEngine> smem_C; alignas(SmemAlignmentD) ArrayEngine> smem_D; @@ -241,6 +236,11 @@ public: static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = IsGroupedGemmKernel ? (NumMaxSchedulerPipelineStageCount + std::max(StagesC, (ReuseSmemC ? StagesC : StagesD)) + 2) : 1; + // Host side epilogue arguments struct Arguments { typename FusionCallbacks::Arguments thread{}; diff --git a/include/cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl index 5cdbdd7e9..6b82ab9f4 100644 --- a/include/cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_9xBF16_interleaved_complex_umma_builder.inl @@ -185,8 +185,7 @@ struct CollectiveBuilder< // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_fast_fp32< ReducedSmemCapacityBytes, CtaTileShape_MNK, TiledMma, BuilderScheduleTag, UmmaMajorACompute, NumComplexComponents, NumComputeMtxs @@ -209,7 +208,8 @@ struct CollectiveBuilder< ScalingFactor, AccPromotionInterval, ClusterShape_MNK, - AccumulatorCopyAtom>, + AccumulatorCopyAtom, + ArchTag>, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedFastF32< Load2TransformPipelineStageCount, Transform2MmaPipelineStageCount, @@ -219,8 +219,9 @@ struct CollectiveBuilder< ScalingFactor, AccPromotionInterval, ClusterShape_MNK, - AccumulatorCopyAtom> - >; + AccumulatorCopyAtom, + ArchTag> + >; using CollectiveOp = cutlass::gemm::collective::CollectiveMma< DispatchPolicy, TileShape_MNK, diff --git a/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl index 4b3454576..3b5bb109c 100644 --- a/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl @@ -253,8 +253,7 @@ struct CollectiveBuilder< // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_fast_fp32< ReducedSmemCapacityBytes, CtaTileShape_MNK, TiledMma, BuilderScheduleTag, UmmaMajorACompute, /*Cmplx=*/ 1, /*Mtxs=*/ NumComputeMtxs @@ -276,7 +275,8 @@ struct CollectiveBuilder< ScalingFactor, AccPromotionInterval, ClusterShape_MNK, - AccumulatorCopyAtom>, + AccumulatorCopyAtom, + ArchTag>, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedFastF32< Load2TransformPipelineStageCount, Transform2MmaPipelineStageCount, @@ -286,8 +286,9 @@ struct CollectiveBuilder< ScalingFactor, AccPromotionInterval, ClusterShape_MNK, - AccumulatorCopyAtom> - >; + AccumulatorCopyAtom, + ArchTag> + >; using CollectiveOp = cutlass::gemm::collective::CollectiveMma< DispatchPolicy, TileShape_MNK, diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl index de9158d4b..72ef10e34 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl @@ -93,6 +93,7 @@ sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync(StageCountAu ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementPairA, class GmemLayoutATag, int AlignmentA, @@ -106,7 +107,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassBlockScaledTensorOp, ElementPairA, GmemLayoutATag, @@ -119,7 +120,11 @@ struct CollectiveBuilder< ClusterShape_MNK, // Static cluster shape (_1, _1, _1) StageCountType, BuilderScheduleTag, - cute::enable_if_t > + cute::enable_if_t< + cute::is_same_v && + (cute::is_same_v + ) + > > { using ElementSFA = typename detail::blockscaled::blockscaled_type::sf_type; @@ -238,12 +243,12 @@ struct CollectiveBuilder< CLCPipelineStorage + CLCResponseStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_mixed_tma_cpasync< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); static_assert(PipelineStages > 0, "Smem usage is too high. Can't create any SMEM buffers for A, B, SFA, and SFB."); using CollectiveOp = cutlass::gemm::collective::CollectiveMma< @@ -251,7 +256,8 @@ struct CollectiveBuilder< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>, + ClusterShape_MNK, + ArchTag>, TileShape_MNK, cute::tuple, StridePairA, diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl index 3d19e76da..26d03961e 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_sparse_umma_builder.inl @@ -278,8 +278,7 @@ struct CollectiveBuilder< using SmemTileShape = cute::Shape; // Calculate SMEM capacity based on ArchTag - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockscaled_sparse< ReducedSmemCapacityBytes, @@ -297,7 +296,8 @@ struct CollectiveBuilder< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>; + ClusterShape_MNK, + ArchTag>; using CollectiveOp = cutlass::gemm::collective::CollectiveMma< DispatchPolicy, diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index 923ce9cf4..3a45b1cd9 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -255,8 +255,7 @@ struct CollectiveBuilder< >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; @@ -271,20 +270,23 @@ struct CollectiveBuilder< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag >, cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag > >, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockScaled< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag > >; diff --git a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl index fc9911d9c..769037d24 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl @@ -375,8 +375,7 @@ struct CollectiveBuilder< >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; using MainloopABPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; @@ -417,12 +416,14 @@ struct CollectiveBuilder< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>, + ClusterShape_MNK, + ArchTag>, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>>; + ClusterShape_MNK, + ArchTag>>; using CollectiveOp = cutlass::gemm::collective::CollectiveMma< DispatchPolicy, diff --git a/include/cutlass/gemm/collective/builders/sm100_common.inl b/include/cutlass/gemm/collective/builders/sm100_common.inl index 284def112..0c8be4ad2 100644 --- a/include/cutlass/gemm/collective/builders/sm100_common.inl +++ b/include/cutlass/gemm/collective/builders/sm100_common.inl @@ -1000,6 +1000,18 @@ struct TrivialBlockscaledMma< TileShape_MNK, ClusterShape_MNK, UmmaMajorA, UmmaMajorB, Instr, BuilderScheduleTag>()); }; +template +constexpr int sm100_reduced_smem_capacity_bytes() { + if constexpr (cute::is_same_v) { + return cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + } + else if constexpr (cute::is_same_v) { + return cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid ArchTag, only Sm10x are supported."); + } +} } // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl index 34ab79db2..425fbd625 100644 --- a/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_cpasync_umma_builder.inl @@ -142,8 +142,7 @@ struct CollectiveBuilder< CLCResponseStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage; @@ -156,7 +155,8 @@ struct CollectiveBuilder< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>, + ClusterShape_MNK, + ArchTag>, TileShape_MNK, ElementA, cutlass::gemm::TagToStrideA_t, diff --git a/include/cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl index e6801432e..4dc229536 100644 --- a/include/cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_interleaved_complex_umma_builder.inl @@ -210,8 +210,7 @@ struct CollectiveBuilder< TensorMapStorage); // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; @@ -230,7 +229,8 @@ struct CollectiveBuilder< AccumulatorPipelineStageCount, TransformationStageCount, ClusterShape_MNK, - AccumulatorCopyAtom + AccumulatorCopyAtom, + ArchTag >, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedInterleavedComplexTF32< PipelineStages, @@ -238,7 +238,8 @@ struct CollectiveBuilder< AccumulatorPipelineStageCount, TransformationStageCount, ClusterShape_MNK, - AccumulatorCopyAtom + AccumulatorCopyAtom, + ArchTag > >; diff --git a/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl index 70726b6d0..2e51f3d0b 100644 --- a/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_mixed_input_umma_builder.inl @@ -307,8 +307,7 @@ struct CollectiveBuilder< TensorMapStorage); // Reduce SMEM capacity available for buffers considering extra B smem and barrier smem allocations - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); static constexpr int ScaleGranularityK = get_ScaleGranularityK(); static constexpr auto stage_info = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_mixed_input< @@ -325,7 +324,8 @@ struct CollectiveBuilder< Transform2MmaPipelineStageCount, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag >; using CollectiveOp = cutlass::gemm::collective::CollectiveMma< DispatchPolicy, diff --git a/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl index d9cb12811..b2648e098 100644 --- a/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl @@ -41,6 +41,7 @@ namespace cutlass::gemm::collective { ///////////////////////////////////////////////////////////////////////////////////////////////// template < + class ArchTag, class ElementA, class GmemLayoutATag, int AlignmentA, @@ -54,7 +55,7 @@ template < class BuilderScheduleTag > struct CollectiveBuilder< - arch::Sm100, + ArchTag, arch::OpClassTensorOp, ElementA, GmemLayoutATag, @@ -67,7 +68,9 @@ struct CollectiveBuilder< ClusterShape_MNK, // Static cluster shape (_1, _1, _1) StageCountType, BuilderScheduleTag, - cute::enable_if_t > + cute::enable_if_t && + (cute::is_same_v + )> > { static_assert(cute::is_static_v, "TileShape has to be static"); @@ -135,20 +138,20 @@ struct CollectiveBuilder< CLCPipelineStorage + CLCResponseStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; - + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; using MainloopPipelineStorage = typename cutlass::PipelineUmmaConsumerAsync<1>::SharedStorage; static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override< - Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); + ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, MainloopPipelineStorage>(StageCountType{}); using CollectiveOp = cutlass::gemm::collective::CollectiveMma< cutlass::gemm::MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>, + ClusterShape_MNK, + ArchTag>, TileShape_MNK, ElementA, cutlass::gemm::TagToStrideA_t, diff --git a/include/cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl index dee762f11..de58581ad 100644 --- a/include/cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_planar_complex_umma_builder.inl @@ -132,9 +132,7 @@ struct CollectiveBuilder< // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; - + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; // Use complex type to calculate SMEM stage count @@ -150,13 +148,15 @@ struct CollectiveBuilder< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag >, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedPlanarComplex< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag > >; diff --git a/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl index 40dfda2e0..ab30e4806 100644 --- a/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_sparse_umma_builder.inl @@ -381,8 +381,7 @@ struct CollectiveBuilder< using SmemTileShape = cute::Shape; // Calculate SMEM capacity based on ArchTag - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_sparse< ReducedSmemCapacityBytes, @@ -399,7 +398,8 @@ struct CollectiveBuilder< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK>, + ClusterShape_MNK, + ArchTag>, TileShape_MNK, ElementA, LayoutPairAE, diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index 3e4d8307e..a8d27d527 100644 --- a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -288,9 +288,7 @@ struct CollectiveBuilder< IsArrayOfPointersGemm >::KernelSmemCarveout; // Reduce SMEM capacity available for buffers considering barrier allocations. - - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; @@ -307,20 +305,23 @@ struct CollectiveBuilder< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag >, cutlass::gemm::MainloopSm100ArrayTmaUmmaWarpSpecialized< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag > >, cutlass::gemm::MainloopSm100TmaUmmaWarpSpecialized< PipelineStages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape_MNK + ClusterShape_MNK, + ArchTag > >; diff --git a/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl index 59864398c..088c7da92 100644 --- a/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm103_blockscaled_umma_builder.inl @@ -498,9 +498,7 @@ struct CollectiveBuilder< TensorMapStorage + TmaPrefetchStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. - static constexpr int ReducedSmemCapacityBytes = - cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; - + static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape, Int, _128>; // SmemAllocTypes are uint8_t. We always allocate 128bytes static constexpr auto PipelineStages = cutlass::gemm::collective::detail::sm103_compute_stage_count_or_override_blockscaled< ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, SmemTileShape, SmemLayoutAtomSFA, SmemLayoutAtomSFB>(StageCountType{}); @@ -512,7 +510,8 @@ struct CollectiveBuilder< SchedulerPipelineStageCount, AccumulatorPipelineStageCount, ClusterShape_MNK, - PrefetchType + PrefetchType, + ArchTag >, cutlass::gemm::MainloopSm103TmaUmmaWarpSpecializedBlockScaled< get<0>(PipelineStages), @@ -520,7 +519,8 @@ struct CollectiveBuilder< SchedulerPipelineStageCount, AccumulatorPipelineStageCount, ClusterShape_MNK, - PrefetchType + PrefetchType, + ArchTag > >; diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp index edacf6a4f..ba95f2021 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp @@ -62,6 +62,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementPairA_, @@ -82,7 +83,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementPairA_, StridePairA_, @@ -108,7 +110,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. using TiledMMA_SF = TiledMMA, @@ -143,11 +146,6 @@ struct CollectiveMma< using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); - // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. - // This should be larger than the total number of TMA requests inflight (from update to issued to returned). - // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). - constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; - using ElementPairA = ElementPairA_; using ElementPairB = ElementPairB_; using ElementAMma = typename TiledMma::ValTypeA; @@ -264,6 +262,11 @@ struct CollectiveMma< static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = IsGroupedGemmKernel ? (SchedulerPipelineStageCount + Stages + 2) : 1; + using TmaInternalElementA = cute::conditional_t; using TmaInternalElementB = cute::conditional_t; diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp index b3fc23175..93a683c04 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized_rcggemm.hpp @@ -64,6 +64,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementPairA_, @@ -84,7 +85,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementPairA_, StridePairA_, @@ -110,7 +112,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. using TiledMMA_SF = TiledMMA, diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp index 157df9c2b..bbd4a4920 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -63,6 +63,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementPairA_, @@ -83,7 +84,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementPairA_, StridePairA_, @@ -111,7 +113,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; // TileShape refers to MmaTileShape to adapt for runtime cluster using TileShape = TileShape_; using TiledMma_SF = TiledMMA, diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp index e949f31db..cc572b528 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp @@ -62,6 +62,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementPairA_, @@ -82,7 +83,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementPairA_, StridePairA_, @@ -108,7 +110,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; using TiledMMA_SF = TiledMMA, Layout>, diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp index e20371d39..bd271ff0c 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_sparse_mma_warpspecialized.hpp @@ -63,6 +63,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementPairA_, @@ -83,7 +84,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementPairA_, LayoutPairA_, @@ -109,7 +111,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; using TiledMMA_SF = TiledMMA, Layout>, diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp index 2af0146e1..c380554c9 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized.hpp @@ -63,6 +63,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, @@ -83,7 +84,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StrideA_, @@ -109,7 +111,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; static constexpr bool IsDynamicCluster = not cute::is_static_v; @@ -123,11 +126,6 @@ struct CollectiveMma< using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); - // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. - // This should be larger than the total number of TMA requests inflight (from update to issued to returned). - // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). - constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; - using ElementA = ElementA_; using ElementAMma = typename TiledMma::ValTypeA; using StrideA = StrideA_; @@ -227,6 +225,11 @@ struct CollectiveMma< static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = IsGroupedGemmKernel ? (SchedulerPipelineStageCount + Stages + 2) : 1; + struct SharedStorage { struct TensorStorage : cute::aligned_struct<128, _0> { cute::ArrayEngine> smem_A; diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp index 291bb1135..33aad2369 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_blockwise_scaling.hpp @@ -60,6 +60,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, @@ -80,7 +81,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StridePairA_, @@ -106,7 +108,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; static constexpr bool IsDynamicCluster = not cute::is_static_v; diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp index 15dd91bc6..7248cfab7 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_emulated.hpp @@ -69,6 +69,7 @@ template < int NumBandsToCompute_, int ScalingFactor_, int AccPromotionInterval_, + class ArchTag_, class AccumulatorCopyAtom_, class ClusterShape, class TileShape_, @@ -93,7 +94,8 @@ struct CollectiveMma< ScalingFactor_, AccPromotionInterval_, ClusterShape, - AccumulatorCopyAtom_>, + AccumulatorCopyAtom_, + ArchTag_>, TileShape_, float, StrideA_, @@ -124,7 +126,8 @@ struct CollectiveMma< ScalingFactor_, AccPromotionInterval_, ClusterShape, - AccumulatorCopyAtom_>; + AccumulatorCopyAtom_, + ArchTag_>; using TileShape = TileShape_; using TiledMma = TiledMma_; static constexpr bool IsDynamicCluster = not cute::is_static_v; diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp index 45b9cb436..e636ffcdf 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_emulated.hpp @@ -68,6 +68,7 @@ template < int NumBandsToCompute_, int ScalingFactor_, int AccPromotionInterval_, + class ArchTag_, class AccumulatorCopyAtom_, class ClusterShape, class TileShape_, @@ -92,7 +93,8 @@ struct CollectiveMma< ScalingFactor_, AccPromotionInterval_, ClusterShape, - AccumulatorCopyAtom_>, + AccumulatorCopyAtom_, + ArchTag_>, TileShape_, complex, StrideA_, @@ -160,7 +162,8 @@ public: ScalingFactor_, AccPromotionInterval_, ClusterShape, - AccumulatorCopyAtom_>; + AccumulatorCopyAtom_, + ArchTag_>; static constexpr bool IsDynamicCluster = not cute::is_static_v; using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp index 33507c9b4..de07149c8 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_interleaved_complex_tf32.hpp @@ -65,6 +65,7 @@ template < int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, int TransformationPipelineStageCount_, + class ArchTag_, class AccumulatorCopyAtom_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) @@ -86,7 +87,8 @@ struct CollectiveMma< AccumulatorPipelineStageCount_, TransformationPipelineStageCount_, ClusterShape, - AccumulatorCopyAtom_>, + AccumulatorCopyAtom_, + ArchTag_>, TileShape_, complex, StrideA_, @@ -144,7 +146,8 @@ public: AccumulatorPipelineStageCount_, TransformationPipelineStageCount_, ClusterShape, - AccumulatorCopyAtom_>; + AccumulatorCopyAtom_, + ArchTag_>; static constexpr bool IsDynamicCluster = not cute::is_static_v; using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp index 44bdc8cf5..bb7163578 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_planar_complex.hpp @@ -64,6 +64,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, class TileShape_, // Static cluster shape or dynamic (int, int, _1) class ElementA_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) @@ -84,7 +85,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StrideA_, @@ -116,7 +118,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), @@ -128,15 +131,17 @@ struct CollectiveMma< using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); - // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. - // This should be larger than the total number of TMA requests inflight (from update to issued to returned). - // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). - constexpr static uint32_t NumTmaDescriptorsPerSm = SchedulerPipelineStageCount + Stages + 2; - using ElementA = ElementA_; using ElementAMma = typename TiledMma::ValTypeA; using StrideA = StrideA_; using InternalStrideA = cute::remove_pointer_t; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Multiple buffer the TMA descriptors for each SM so that we can update them asynchronously. + // This should be larger than the total number of TMA requests inflight (from update to issued to returned). + // This can be calculated by SchedulerStages + max(TmaStages) + 2 (for consumer and producer in-flight accessies). + constexpr static uint32_t NumTmaDescriptorsPerSm = IsGroupedGemmKernel ? (SchedulerPipelineStageCount + Stages + 2) : 1; using ElementB = ElementB_; using ElementBMma = typename TiledMma::ValTypeB; using StrideB = StrideB_; diff --git a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp index 9bd85aeea..adf39e63d 100644 --- a/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_array_warpspecialized_rcggemm.hpp @@ -65,6 +65,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, @@ -85,7 +86,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StrideA_, @@ -111,7 +113,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; static constexpr bool IsDynamicCluster = not cute::is_static_v; diff --git a/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp index f210d5d7f..47d12be81 100644 --- a/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_cpasync_warpspecialized.hpp @@ -60,6 +60,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, @@ -80,7 +81,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StrideA_, @@ -107,7 +109,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; // TileShape refers to MmaTileShape to adapt for runtime cluster shape using TileShape = TileShape_; diff --git a/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp index 2a0569395..945e5feee 100644 --- a/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -62,6 +62,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, @@ -82,7 +83,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StrideA_, @@ -111,7 +113,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; // TileShape refers to MmaTileShape to adapt for runtime cluster using TileShape = TileShape_; diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp index 4644eaeba..3390f5c4a 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized.hpp @@ -62,6 +62,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, @@ -82,7 +83,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StrideA_, @@ -108,7 +110,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; static constexpr bool IsDynamicCluster = not cute::is_static_v; diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp index 3f008c051..6651f9abf 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp @@ -62,6 +62,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, @@ -82,7 +83,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StridePairA_, @@ -108,7 +110,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; using ElementA = ElementA_; diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp index 1be80601e..1b2ac920e 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_emulated.hpp @@ -82,6 +82,7 @@ template < int NumBandsToCompute_, int ScalingFactor_, int AccPromotionInterval_, + class ArchTag_, class AccumulatorCopyAtom_, class ClusterShape, class TileShape_, @@ -106,7 +107,8 @@ struct CollectiveMma< ScalingFactor_, AccPromotionInterval_, ClusterShape, - AccumulatorCopyAtom_>, + AccumulatorCopyAtom_, + ArchTag_>, TileShape_, float, StrideA_, @@ -137,7 +139,8 @@ struct CollectiveMma< ScalingFactor_, AccPromotionInterval_, ClusterShape, - AccumulatorCopyAtom_>; + AccumulatorCopyAtom_, + ArchTag_>; using TileShape = TileShape_; using TiledMma = TiledMma_; static constexpr bool IsDynamicCluster = not cute::is_static_v; diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp index a8fe8c4a0..813d3561c 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_emulated.hpp @@ -66,6 +66,7 @@ template < int NumBandsToCompute_, int ScalingFactor_, int AccPromotionInterval_, + class ArchTag_, class AccumulatorCopyAtom_, class ClusterShape, class TileShape_, @@ -90,7 +91,8 @@ struct CollectiveMma< ScalingFactor_, AccPromotionInterval_, ClusterShape, - AccumulatorCopyAtom_>, + AccumulatorCopyAtom_, + ArchTag_>, TileShape_, complex, StrideA_, @@ -156,7 +158,8 @@ public: ScalingFactor_, AccPromotionInterval_, ClusterShape, - AccumulatorCopyAtom_>; + AccumulatorCopyAtom_, + ArchTag_>; static constexpr bool IsDynamicCluster = not cute::is_static_v; using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp index a1f25017a..bf8a7fa00 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_interleaved_complex_tf32.hpp @@ -78,6 +78,7 @@ template < int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, int TransformationPipelineStageCount_, + class ArchTag_, class AccumulatorCopyAtom_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) @@ -99,7 +100,8 @@ struct CollectiveMma< AccumulatorPipelineStageCount_, TransformationPipelineStageCount_, ClusterShape, - AccumulatorCopyAtom_>, + AccumulatorCopyAtom_, + ArchTag_>, TileShape_, complex, StrideA_, @@ -155,7 +157,8 @@ public: AccumulatorPipelineStageCount_, TransformationPipelineStageCount_, ClusterShape, - AccumulatorCopyAtom_>; + AccumulatorCopyAtom_, + ArchTag_>; static constexpr bool IsDynamicCluster = not cute::is_static_v; using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp index 97fbd3367..8530388b8 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_mixed_input.hpp @@ -65,6 +65,7 @@ template < int Transform2MmaPipelineStageCount_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, + class ArchTag_, class ClusterShape, class TileShape_, class ElementAOptionalTuple_, @@ -86,7 +87,8 @@ struct CollectiveMma< Transform2MmaPipelineStageCount_, SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementAOptionalTuple_, StridePairA_, @@ -115,7 +117,8 @@ public: Transform2MmaPipelineStageCount_, SchedulerPipelineStageCount_, AccumulatorPipelineStageCount_, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; using TiledMma = TiledMma_; using KernelSchedule = typename DispatchPolicy::Schedule; diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp index 3d9c11e49..eabad12d0 100644 --- a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_planar_complex.hpp @@ -70,6 +70,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, class TileShape_, // Static cluster shape or dynamic (int, int, _1) class ElementA_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) @@ -90,7 +91,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, StrideA_, @@ -122,7 +124,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), diff --git a/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp index 185ab050a..b9e0fb2a8 100644 --- a/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_sparse_mma_warpspecialized.hpp @@ -61,6 +61,7 @@ template < int Stages, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, _1) class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, @@ -81,7 +82,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>, + ClusterShape, + ArchTag_>, TileShape_, ElementA_, LayoutPairAE_, @@ -107,7 +109,8 @@ struct CollectiveMma< Stages, SchedulerPipelineStageCount, AccumulatorPipelineStageCount, - ClusterShape>; + ClusterShape, + ArchTag_>; using TileShape = TileShape_; static constexpr bool IsDynamicCluster = not cute::is_static_v; diff --git a/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp b/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp index 2030b9a5d..602971d75 100644 --- a/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm103_blockscaled_mma_array_warpspecialized.hpp @@ -66,6 +66,7 @@ template < int LoadSFPipelineStageCount, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, int) cutlass::sm103::detail::KernelPrefetchType PrefetchType, class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) @@ -89,7 +90,8 @@ struct CollectiveMma< SchedulerPipelineStageCount, AccumulatorPipelineStageCount, ClusterShape, - PrefetchType>, + PrefetchType, + ArchTag_>, TileShape_, ElementPairA_, StridePairA_, @@ -117,7 +119,8 @@ struct CollectiveMma< SchedulerPipelineStageCount, AccumulatorPipelineStageCount, ClusterShape, - PrefetchType>; + PrefetchType, + ArchTag_>; using TileShape = TileShape_; // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. diff --git a/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp b/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp index 920d7e7e8..b8a21c4e3 100644 --- a/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm103_blockscaled_mma_warpspecialized.hpp @@ -66,6 +66,7 @@ template < int LoadSFPipelineStageCount, int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, + class ArchTag_, class ClusterShape, // Static cluster shape or dynamic (int, int, int) cutlass::sm103::detail::KernelPrefetchType PrefetchType, class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) @@ -89,7 +90,8 @@ struct CollectiveMma< SchedulerPipelineStageCount, AccumulatorPipelineStageCount, ClusterShape, - PrefetchType>, + PrefetchType, + ArchTag_>, TileShape_, ElementPairA_, StridePairA_, @@ -117,7 +119,8 @@ struct CollectiveMma< SchedulerPipelineStageCount, AccumulatorPipelineStageCount, ClusterShape, - PrefetchType>; + PrefetchType, + ArchTag_>; using TileShape = TileShape_; // Due to an MSVC bug, we can't use decltype(make_tiled_mma()) interface. diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 3ea713993..7d3043e09 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -976,12 +976,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100UmmaCpAsyncWarpSpecialized { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelWarpSpecializedSm100; }; @@ -989,12 +990,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100; constexpr static bool IsOverlappingAccum = false; }; @@ -1003,12 +1005,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelMixedTmaCpAsyncWarpSpecializedSm100; constexpr static bool IsOverlappingAccum = false; }; @@ -1018,12 +1021,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecialized { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelTmaWarpSpecializedSm100; constexpr static bool IsOverlappingAccum = false; }; @@ -1033,12 +1037,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelTmaWarpSpecializedMmaTransformSm100; constexpr static bool IsOverlappingAccum = false; }; @@ -1048,12 +1053,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100ArrayTmaUmmaWarpSpecializedBlockwiseScaling { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelPtrArrayTmaWarpSpecializedMmaTransformSm100; constexpr static bool IsOverlappingAccum = false; }; @@ -1063,12 +1069,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedBlockScaled { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; using Schedule = KernelTmaWarpSpecializedBlockScaledSm100; }; @@ -1077,13 +1084,14 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedSparse { constexpr static int Stages = Stages_; constexpr static int MetadataS2TStages = 4; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; using Schedule = KernelSparseTmaWarpSpecializedSm100; }; @@ -1092,13 +1100,14 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedBlockScaledSparse { constexpr static int Stages = Stages_; constexpr static int MetadataS2TStages = 4; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; using Schedule = KernelSparseTmaWarpSpecializedBlockScaledSm100; }; @@ -1132,7 +1141,8 @@ template< class ClusterShape_ = Shape<_1,_1,_1>, // The TMEM_LOAD atom to be used for loading local accumulator // from TMEM to registers - class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_32dp32b32x + class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_32dp32b32x, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedFastF32 { constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_; @@ -1143,7 +1153,7 @@ struct MainloopSm100TmaUmmaWarpSpecializedFastF32 { constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::FastF32; using ClusterShape = ClusterShape_; using AccumulatorCopyAtom = AccumulatorCopyAtom_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelTmaWarpSpecializedInputTransformSm100; // For backwards compatibility with GemmUniversalAdapter. @@ -1163,7 +1173,8 @@ template< // Transformation <-> MMA int TransformationPipelineStageCount_, class ClusterShape_ = Shape<_1,_1,_1>, - class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_16dp256b1x + class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_16dp256b1x, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedInterleavedComplexTF32 { constexpr static int ComputationPipelineStageCount = ComputationPipelineStageCount_; @@ -1171,7 +1182,7 @@ struct MainloopSm100TmaUmmaWarpSpecializedInterleavedComplexTF32 { constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::InterleavedComplexTF32; using ClusterShape = ClusterShape_; using AccumulatorCopyAtom = AccumulatorCopyAtom_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelTmaWarpSpecializedInputTransformSm100; // For backwards compatibility with GemmUniversalAdapter. @@ -1191,7 +1202,8 @@ template< // Accmulator pipeline depth int AccumulatorPipelineStageCount_, // ClusterShape for the kernel - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedMixedInput { constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_; @@ -1199,7 +1211,7 @@ struct MainloopSm100TmaUmmaWarpSpecializedMixedInput { constexpr static int Transform2MmaPipelineStageCount = Transform2MmaPipelineStageCount_; constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::MixedInput; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelTmaWarpSpecializedMixedInputTransformSm100; // For backwards compatibility with GemmUniversalAdapter. @@ -1214,12 +1226,13 @@ template< int SchedulerPipelineStageCount_, // Accmulator pipeline depth int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100TmaUmmaWarpSpecializedPlanarComplex { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelTmaWarpSpecializedSm100; constexpr static bool IsOverlappingAccum = false; }; @@ -1229,12 +1242,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100ArrayTmaUmmaWarpSpecialized { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = false; using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; }; @@ -1244,12 +1258,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100RCGroupGemmTmaUmmaWarpSpecialized { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = false; using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; }; @@ -1259,12 +1274,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100RCGroupGemmTmaUmmaWarpSpecializedBlockScaled { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm100; }; @@ -1274,12 +1290,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100ArrayTmaUmmaWarpSpecializedBlockScaled { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm100; }; @@ -1291,12 +1308,13 @@ template< int Stages_, int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class ArchTag_ = arch::Sm100 > struct MainloopSm100ArrayTmaUmmaWarpSpecializedPlanarComplex { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = false; using Schedule = KernelPtrArrayTmaWarpSpecializedSm100; }; @@ -1331,7 +1349,8 @@ template< class ClusterShape_ = Shape<_1,_1,_1>, // The TMEM_LOAD atom to be used for loading local accumulator // from TMEM to registers - class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_32dp32b32x + class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_32dp32b32x, + class ArchTag_ = arch::Sm100 > struct MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32 { constexpr static int Load2TransformPipelineStageCount = Load2TransformPipelineStageCount_; @@ -1342,7 +1361,7 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecializedFastF32 { constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::FastF32; using ClusterShape = ClusterShape_; using AccumulatorCopyAtom = AccumulatorCopyAtom_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelPtrArrayTmaWarpSpecializedInputTransformSm100; // For backwards compatibility with GemmUniversalAdapter. @@ -1362,7 +1381,8 @@ template< // Transformation <-> MMA int TransformationPipelineStageCount_, class ClusterShape_ = Shape<_1,_1,_1>, - class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_16dp256b1x + class AccumulatorCopyAtom_ = cute::SM100_TMEM_LOAD_16dp256b1x, + class ArchTag_ = arch::Sm100 > struct MainloopSm100ArrayTmaUmmaWarpSpecializedInterleavedComplexTF32 { constexpr static int ComputationPipelineStageCount = ComputationPipelineStageCount_; @@ -1370,7 +1390,7 @@ struct MainloopSm100ArrayTmaUmmaWarpSpecializedInterleavedComplexTF32 { constexpr static detail::KernelInputTransformType InputTransformType = detail::KernelInputTransformType::InterleavedComplexTF32; using ClusterShape = ClusterShape_; using AccumulatorCopyAtom = AccumulatorCopyAtom_; - using ArchTag = arch::Sm100; + using ArchTag = ArchTag_; using Schedule = KernelPtrArrayTmaWarpSpecializedInputTransformSm100; // For backwards compatibility with GemmUniversalAdapter. @@ -1385,13 +1405,14 @@ template< int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, class ClusterShape_ = Shape<_1,_1,_1>, - cutlass::sm103::detail::KernelPrefetchType PrefetchType_ = cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch + cutlass::sm103::detail::KernelPrefetchType PrefetchType_ = cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch, + class ArchTag_ = arch::Sm103 > struct MainloopSm103TmaUmmaWarpSpecializedBlockScaled { constexpr static int LoadABPipelineStageCount = LoadABPipelineStageCount_; constexpr static int LoadSFPipelineStageCount = LoadSFPipelineStageCount_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm103; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; using Schedule = KernelTmaWarpSpecializedBlockScaledSm103; // For backwards compatibility with GemmUniversalAdapter. @@ -1407,13 +1428,14 @@ template< int SchedulerPipelineStageCount_, int AccumulatorPipelineStageCount_, class ClusterShape_ = Shape<_1,_1,_1>, - cutlass::sm103::detail::KernelPrefetchType PrefetchType_ = cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch + cutlass::sm103::detail::KernelPrefetchType PrefetchType_ = cutlass::sm103::detail::KernelPrefetchType::TmaPrefetch, + class ArchTag_ = arch::Sm103 > struct MainloopSm103ArrayTmaUmmaWarpSpecializedBlockScaled { constexpr static int LoadABPipelineStageCount = LoadABPipelineStageCount_; constexpr static int LoadSFPipelineStageCount = LoadSFPipelineStageCount_; using ClusterShape = ClusterShape_; - using ArchTag = arch::Sm103; + using ArchTag = ArchTag_; constexpr static bool IsOverlappingAccum = AccumulatorPipelineStageCount_ == 1; using Schedule = KernelPtrArrayTmaWarpSpecializedBlockScaledSm103; // For backwards compatibility with GemmUniversalAdapter. diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp index 18d1cd617..ca5597946 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp @@ -144,7 +144,7 @@ public: using TileSchedulerParams = typename TileScheduler::Params; static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; - static constexpr bool IsTensorMapUpdateAsync = not IsSchedDynamicPersistent; + static constexpr bool IsTensorMapUpdateAsync = IsGroupedGemmKernel && not IsSchedDynamicPersistent; static constexpr bool IsDynamicCluster = not cute::is_static_v; static constexpr uint32_t MinTensorMapWorkspaceAlignment = 64; diff --git a/include/cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp index 776026f70..b3719d037 100644 --- a/include/cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm100_static_tile_scheduler.hpp @@ -44,8 +44,18 @@ public StaticPersistentTileScheduler< public: using BaseScheduler = StaticPersistentTileScheduler; public: - using BaseScheduler::StaticPersistentTileScheduler; using Params = PersistentTileSchedulerSm90Params; + + // Explicit forwarding constructors replacing inheriting-constructor syntax + // (`using BaseScheduler::StaticPersistentTileScheduler;`) which newer CUDA + // host compilers reject in dependent-base contexts: the injected-class-name + // resolves to a type rather than a constructor name. + CUTLASS_HOST_DEVICE + StaticPersistentTileScheduler100() = default; + + CUTLASS_DEVICE explicit + StaticPersistentTileScheduler100(Params const& params_) + : BaseScheduler(params_) {} using RasterOrder = typename Params::RasterOrder; using RasterOrderOptions = typename Params::RasterOrderOptions; struct CLCResponse { uint32_t data[4] = {0}; }; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 587a5f4bc..c6a94cbf5 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -242,8 +242,10 @@ public: CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); } - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - + KernelHardwareInfo hw_info = args.hw_info; + hw_info.sm_count = sm_count; + hw_info.max_active_clusters = max_active_clusters; + // Calculate workspace pointers uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index 1734c1648..1fd9fb0ec 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -245,8 +245,10 @@ public: CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); } - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - + KernelHardwareInfo hw_info = args.hw_info; + hw_info.sm_count = sm_count; + hw_info.max_active_clusters = max_active_clusters; + // Calculate workspace pointers uint8_t* workspace_ptr = reinterpret_cast(workspace); size_t workspace_offset = 0; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 26e1b09f2..2eaab2458 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -193,8 +193,10 @@ public: CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); } - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - + KernelHardwareInfo hw_info = args.hw_info; + hw_info.sm_count = sm_count; + hw_info.max_active_clusters = max_active_clusters; + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index 379d4cb7f..d44839884 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -206,8 +206,10 @@ public: CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid cluster count to " << max_active_clusters); } - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count, max_active_clusters}; - + KernelHardwareInfo hw_info = args.hw_info; + hw_info.sm_count = sm_count; + hw_info.max_active_clusters = max_active_clusters; + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace); diff --git a/include/cutlass/kernel_hardware_info.h b/include/cutlass/kernel_hardware_info.h index ce09b9d2b..6b8c2427b 100644 --- a/include/cutlass/kernel_hardware_info.h +++ b/include/cutlass/kernel_hardware_info.h @@ -81,21 +81,26 @@ struct KernelHardwareInfo { } // Query maximum number of active clusters that could co-exist on the target device - // based on kernel properties such as cluster dims and threadblock dims + // based on kernel properties such as cluster dims and threadblock dims. + // When a green context stream is provided, the occupancy query is scoped to the + // green context's SM partition, returning the max active clusters for that partition. static inline int query_device_max_active_clusters( dim3 cluster_dims, uint32_t threads_per_block, - void const* kernel_ptr) { + void const* kernel_ptr, + cudaStream_t stream = nullptr) { int max_active_clusters = 0; #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) ClusterLauncher::LaunchConfig cluster_launch_config = ClusterLauncher::make_cluster_launch_config( - cluster_dims /* minimum grid dim */, cluster_dims, {threads_per_block, 1, 1}); + cluster_dims /* minimum grid dim */, cluster_dims, {threads_per_block, 1, 1}, + 0 /* smem_size */, stream /* green ctx stream or nullptr */); // Given the kernel function and launch configuration, return the maximum number of clusters that could co-exist on the target device. + // When stream is a green context stream, this returns the max active clusters for that partition. cudaError_t result = cudaOccupancyMaxActiveClusters(&max_active_clusters, kernel_ptr, &cluster_launch_config.launch_config); if (result != cudaSuccess) { CUTLASS_TRACE_HOST( - " cudaGetDevice() returned error " + " cudaOccupancyMaxActiveClusters() returned error " << cudaGetErrorString(result)); return 0; } @@ -108,26 +113,31 @@ struct KernelHardwareInfo { #endif } - // Simpler version of the above query function that fetches relevant information from the Kernel + // Simpler version of the above query function that fetches relevant information from the Kernel. + // When a green context stream is provided, the occupancy query is scoped to that partition. template static inline int - query_device_max_active_clusters() { + query_device_max_active_clusters(cudaStream_t stream = nullptr) { dim3 cluster_dims(cute::size<0>(typename Kernel::ClusterShape{}), cute::size<1>(typename Kernel::ClusterShape{}), cute::size<2>(typename Kernel::ClusterShape{})); uint32_t threads_per_block = Kernel::MaxThreadsPerBlock; void const* kernel_ptr = (void*)(device_kernel); - return query_device_max_active_clusters(cluster_dims, threads_per_block, kernel_ptr); + return query_device_max_active_clusters(cluster_dims, threads_per_block, kernel_ptr, stream); } + // Create a KernelHardwareInfo by querying device properties. + // When a green context stream is provided, max_active_clusters is queried + // against that stream's green context partition instead of the full device. template static inline KernelHardwareInfo - make_kernel_hardware_info(int const device_id = 0, int sm_count = 0, int max_active_clusters = 0) { + make_kernel_hardware_info(int const device_id = 0, int sm_count = 0, int max_active_clusters = 0, + cudaStream_t stream = nullptr) { if (sm_count == 0) { sm_count = query_device_multiprocessor_count(device_id); } if (max_active_clusters == 0) { - max_active_clusters = query_device_max_active_clusters(); + max_active_clusters = query_device_max_active_clusters(stream); } return {device_id, sm_count, max_active_clusters}; } diff --git a/include/cutlass/version.h b/include/cutlass/version.h index f388aa75e..5c30d8c6a 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -35,8 +35,8 @@ #include #define CUTLASS_MAJOR 4 -#define CUTLASS_MINOR 4 -#define CUTLASS_PATCH 2 +#define CUTLASS_MINOR 5 +#define CUTLASS_PATCH 0 #ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/media/docs/pythonDSL/cute_dsl.rst b/media/docs/pythonDSL/cute_dsl.rst index 5ead90f06..0dfbf0656 100644 --- a/media/docs/pythonDSL/cute_dsl.rst +++ b/media/docs/pythonDSL/cute_dsl.rst @@ -13,6 +13,7 @@ CuTe DSL JIT Argument: Layouts JIT Caching JIT Compilation Options + JIT Types Integration with Frameworks Debugging with the DSL Autotuning with the DSL diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst index 73b7e236f..de1f088ab 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst @@ -82,6 +82,41 @@ Defines GPU kernel functions, compiled as specialized GPU symbols through |DC|. - ``smem`` Specifies the size of shared memory in bytes (integer). + - ``None`` (default) — Automatically calculates the kernel's shared memory usage via **utils.SmemAllocator**. Recommended unless manual control is required. + - ``int`` — Manually specifies the size of shared memory in bytes. + +**Additional Kernel Launch Parameters**: + +- ``fallback_cluster`` + Specifies the minimum-guaranteed cluster size. When set, ``cluster`` becomes the **preferred** size, enabling graceful degradation when hardware cannot satisfy the preferred dimensions. + + - ``None`` (default) — No fallback; ``cluster`` is used directly. + - ``list[int]`` — Three-element list [x, y, z]. + +- ``max_number_threads`` + Specifies the maximum thread count per block (**maxntid**). + + - ``[0, 0, 0]`` (default) — Auto-generate **reqntid** from ``block``. + - ``list[int]`` — Three-element list [x, y, z]. + +- ``min_blocks_per_mp`` + Specifies the minimum blocks per multiprocessor (**minctasm**). + + - ``0`` (default) — No minimum occupancy hint. + - ``int`` — Minimum number of blocks per multiprocessor. + +- ``use_pdl`` + Enables Programmatic Dependent Launch (PDL) to overlap dependent kernel launches in the same stream. + + - ``False`` (default) — PDL disabled. + - ``True`` — PDL enabled. + +- ``cooperative`` + Enables cooperative kernel launch; all thread blocks launch cooperatively with grid-wide synchronization support. + + - ``False`` (default) — Standard kernel launch. + - ``True`` — Cooperative kernel launch. + Calling Conventions ------------------- diff --git a/media/docs/pythonDSL/cute_dsl_general/types.rst b/media/docs/pythonDSL/cute_dsl_general/types.rst new file mode 100644 index 000000000..42452c4fd --- /dev/null +++ b/media/docs/pythonDSL/cute_dsl_general/types.rst @@ -0,0 +1,632 @@ +.. _types: +.. |DSL| replace:: CuTe DSL + +Types +===== + +Overview +-------- + +|DSL| provides a set of core types that form the foundation of tensor layout algebra and GPU programming. These types enable precise control over memory layout, data representation, and tensor operations. This document covers the key types available in ``cutlass.cute.core``. + + +Core Numeric Types +------------------ + +IntValue +~~~~~~~~ + +``IntValue`` is an internal representation of constrained integer types with divisibility information. It serves as a proxy for constrained integer types in the CuTe IR, automatically tracking divisibility constraints that are crucial for layout operations. + +**Key Features:** + +- Inherits from ``ArithValue`` with extensions for divisibility tracking +- Automatically emits ``cute.get_scalars`` operations in the IR +- Supports arithmetic operations that propagate divisibility information +- Used internally for type-safe integer operations in layout algebra + +**API Methods:** + +- ``get_typed_value()`` - Returns the value as an IntTupleType +- ``get_divisibility()`` - Returns the divisibility constraint of the value +- ``divisibility`` - Property that returns the divisibility constraint + +**Supported Operations:** + +The ``IntValue`` type supports standard arithmetic operations with divisibility tracking: + +.. code-block:: python + + # Addition, subtraction, multiplication, division, and modulo + result = int_val1 + int_val2 + result = int_val1 - int_val2 + result = int_val1 * int_val2 + result = int_val1 // int_val2 + result = int_val1 % int_val2 + +**String Representation:** + +.. code-block:: python + + # IntValue with divisibility 1 + str(int_val) # Returns "?" + + # IntValue with divisibility 4 + str(int_val) # Returns "?{div=4}" + + +Ratio +~~~~~ + +``Ratio`` represents a rational number as a ratio of two integers. It is used in CuTe to represent exact fractional values that arise in tensor layout operations, particularly in composition operations where divisibility conditions may not be satisfied. + +**Constructor:** + +.. code-block:: python + + ratio = cute.Ratio(numerator, denominator) + +:param numerator: The numerator of the ratio +:type numerator: int +:param denominator: The denominator of the ratio +:type denominator: int +:raises TypeError: If numerator or denominator are not integers + +**Methods:** + +- ``is_integral()`` - Returns ``True`` if the ratio represents an integer value (numerator divisible by denominator) +- ``reduced()`` - Returns a new Ratio with numerator and denominator reduced to lowest terms +- ``to(dtype)`` - Converts the ratio to another type (Ratio, float, or int) + +**Arithmetic Operations:** + +.. code-block:: python + + # Multiplication with another ratio + ratio1 = cute.Ratio(1, 2) + ratio2 = cute.Ratio(3, 4) + result = ratio1 * ratio2 # Returns Ratio(3, 8) + + # Multiplication with integer + ratio = cute.Ratio(2, 3) + result = ratio * 5 # Returns Ratio(10, 3) + result = 5 * ratio # Returns Ratio(10, 3) + +**Type Conversion:** + +.. code-block:: python + + ratio = cute.Ratio(3, 2) + + # Convert to float + float_val = ratio.to(float) # Returns 1.5 + + # Convert to int (floor division) + int_val = ratio.to(int) # Returns 1 + + +Layout Algebra Types +-------------------- + +ScaledBasis +~~~~~~~~~~~ + +``ScaledBasis`` represents a scaled basis element in CuTe's layout algebra. It consists of a scale value and a mode that identifies which basis element in the layout algebra is being referenced. ScaledBasis elements are fundamental to CuTe's coordinate system representation. + +**Constructor:** + +.. code-block:: python + + sb = cute.ScaledBasis(value, mode) + +:param value: The scale value +:type value: Union[int, Integer, Ratio, ir.Value] +:param mode: The mode identifying the basis element +:type mode: Union[int, List[int]] +:raises TypeError: If mode is not an integer or list of integers + +**Examples:** + +.. code-block:: python + + # Create a scaled basis with integer scale and mode + sb1 = cute.ScaledBasis(2, 0) # 2 * E(0) + + # Create a scaled basis with a Ratio scale + sb2 = cute.ScaledBasis(cute.Ratio(1, 2), 1) # (1/2) * E(1) + + # Create a scaled basis with a list of modes + sb3 = cute.ScaledBasis(4, [0, 1]) # 4 * E([0, 1]) + + # Scaled basis elements are commonly used in layout strides + layout = cute.make_layout((4, 8), stride=(cute.ScaledBasis(2, 0), cute.ScaledBasis(1, 1))) + + # This creates a layout with strides (2@0, 1@1) representing + # a coordinate system where each dimension has its own basis + + # Example: Mapping coordinates to indices using the layout + coord = (2, 3) + idx = cute.crd2idx(coord, layout) # Maps (2, 3) to (4, 3) + +**Properties:** + +- ``value`` - Get the scale value +- ``mode`` - Get the mode as a list of integers +- ``is_static()`` - Returns ``True`` if the value is statically known + +**Methods:** + +- ``to(dtype)`` - Convert to another type (ScaledBasis or internal _ScaledBasis) + +**Operations:** + +.. code-block:: python + + # Right multiplication by a scale factor + sb = cute.ScaledBasis(2, 0) + result = 3 * sb # Creates ScaledBasis(6, 0) + +**Utility Function:** + +.. code-block:: python + + # Create a basis element with unit scale + basis = cute.E(mode) # Equivalent to ScaledBasis(1, mode) + + +Swizzle +~~~~~~~ + +``Swizzle`` is a transformation that permutes the elements of a layout. Swizzles are used to rearrange data elements to improve memory access patterns and computational efficiency, particularly for avoiding bank conflicts in shared memory. + +**Swizzle Parameters:** + +A swizzle is defined by three parameters: + +- **MBase**: The number of least-significant bits to keep constant +- **BBits**: The number of bits in the mask +- **SShift**: The distance to shift the mask + +**Bit Pattern:** + +.. code-block:: text + + 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + ^--^ MBase (least-sig bits kept constant) + ^-^ ^-^ BBits (number of bits in mask) + ^---------^ SShift (distance to shift YYY) + (positive: right, negative: left) + + Given: 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + Result: 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx + where AA = ZZ xor YY + +**Usage:** + +Swizzles are typically created using CuTe's swizzle factory functions and composed with layouts to create optimized memory access patterns. + + +Layout +~~~~~~ + +``Layout`` is CuTe's core abstraction for representing tensor layouts. A Layout maps from a logical coordinate space to an index space, defined by a pair of (Shape, Stride). Layouts present a common interface to multidimensional array access that abstracts away the details of how array elements are organized in memory. + +**Key Concepts:** + +- **Shape**: Defines the abstract dimensions of the Layout +- **Stride**: Defines how coordinates within the Shape map to linear indices +- **Hierarchical Structure**: CuTe layouts are inherently hierarchical, constructed from smaller nested layouts + +**Properties:** + +- ``shape`` - An IntTuple representing the dimensions of the layout +- ``stride`` - An IntTuple representing the strides of the layout +- ``max_alignment`` - The maximum alignment of the layout in bytes + +**Examples:** + +.. code-block:: python + + # Creating a layout with shape (4,8) and default stride (column major) + layout = cute.make_layout((4, 8)) + + # Creating a layout with explicit shape and stride (row major) + layout = cute.make_layout((4, 8), stride=(8, 1)) + + # Accessing layout properties + shape = layout.shape # Returns (4, 8) + stride = layout.stride # Returns (8, 1) + + # Mapping a coordinate to an index: (2, 3) -> 2 * 8 + 3 * 1 = 19 + idx = cute.crd2idx((2, 3), layout) + +**Layout Operations:** + +Layouts support a rich algebra of operations: + +- **Concatenation**: Combining layouts along dimensions +- **Coalescence**: Merging adjacent modes +- **Composition**: Composing layouts with functions or other layouts +- **Complement**: Computing the complement space +- **Inversion**: Inverting the layout mapping + +**String Representation:** + +.. code-block:: python + + layout = cute.make_layout((4, 8), stride=(1, 4)) + print(layout) # Prints "shape:stride" format, e.g., "(4,8):(1,4)" + + +ComposedLayout +~~~~~~~~~~~~~~ + +``ComposedLayout`` represents a composition of layouts and transformations. It is a generalization of normal layouts that can support arbitrary function mappings from coordinate to coordinate as an inner layout. + +**Structure:** + +A ComposedLayout consists of three components: + +- **inner**: The inner transformation (Swizzle or Layout) +- **offset**: An offset applied to coordinates +- **outer**: The outer layout + +**Properties:** + +- ``inner`` - Returns the inner transformation (Union[Swizzle, Layout]) +- ``offset`` - Returns the offset as an IntTuple +- ``outer`` - Returns the outer layout +- ``shape`` - Returns the shape of the composed layout +- ``max_alignment`` - Returns the maximum alignment +- ``is_normal`` - Returns ``True`` if this is a normal layout (not a general composition) + +**Examples:** + +.. code-block:: python + + # ComposedLayouts are typically created through composition operations + # For example, composing a layout with a swizzle + layout = cute.make_layout((8, 8)) + swizzle = cute.make_swizzle(...) + composed = cute.composition(swizzle, layout) + + # Accessing components + inner = composed.inner # Returns the swizzle + outer = composed.outer # Returns the layout + offset = composed.offset # Returns the offset + +**String Representation:** + +.. code-block:: python + + print(composed) # Prints "inner o offset o outer" format + + +Memory and Pointer Types +------------------------- + +Pointer +~~~~~~~ + +``Pointer`` represents a memory address with specific properties. Pointers are a fundamental type of iterator/engine that support random-access operations. They can be offset by elements of a layout's codomain and dereferenced to produce values. + +**Properties:** + +- ``dtype`` - The type of value this pointer points to +- ``type`` - The MLIR type of the pointer +- ``memspace`` - The memory space where the pointer data resides (e.g., ``gmem``, ``smem``, ``rmem``) +- ``alignment`` - The alignment of the pointer in bytes +- ``max_alignment`` - The maximum alignment of the pointer in bytes + +**Operations:** + +.. code-block:: python + + # Pointer arithmetic + ptr2 = ptr + offset # Offset pointer forward + ptr3 = offset + ptr # Offset pointer forward (commutative) + ptr4 = ptr - offset # Offset pointer backward + + # Convert pointer to integer + int_addr = ptr.toint() + + # Align pointer to specified byte boundary + aligned_ptr = ptr.align(16) # Align to 16-byte boundary + +**Tensor Composition:** + +When composed with a layout, a pointer forms a tensor: ``T = E ∘ L``, where ``E`` is the pointer (engine) and ``L`` is the layout. The tensor evaluates the layout by mapping a coordinate ``c`` to the codomain, offsets the pointer accordingly, and dereferences the result: + +.. code-block:: text + + T(c) = (E ∘ L)(c) = *(E + L(c)) + +**Methods:** + +- ``llvm_ptr`` - Get the LLVM pointer representation (low-level use only) +- ``align(min_align)`` - Align pointer to specified byte alignment (must be power of 2) +- ``toint()`` - Convert pointer to integer address (Int64 for gmem/generic, Int32 otherwise) + +**Examples:** + +.. code-block:: python + + # Create a pointer from a tensor's data + ptr = tensor.data() + + # Offset the pointer + offset_ptr = ptr + 16 + + # Check pointer properties + print(f"Memory space: {ptr.memspace}") + print(f"Alignment: {ptr.alignment}") + print(f"Data type: {ptr.dtype}") + + +Structured Data Types +--------------------- + +struct +~~~~~~ + +The ``struct`` decorator abstracts C structures in Python DSL. It allows you to define structured data types with precise control over layout, alignment, and nesting. + +**Supported Elements:** + +- Base DSL scalar int/float elements +- Arrays (MemRange) +- Nested structures +- Aligned elements + +**Basic Usage:** + +.. code-block:: python + + # Define a simple struct + @cute.struct + class complex: + real : cutlass.Float32 + imag : cutlass.Float32 + + # Define a struct with arrays and nested structures + @cute.struct + class StorageA: + mbarA : cute.struct.MemRange[cutlass.Int64, stage] + compA : complex + intA : cutlass.Int16 + +**Alignment Control:** + +.. code-block:: python + + # Define a struct with explicit alignment + @cute.struct + class StorageB: + a: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, size_a], 1024 + ] + b: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, size_b], 1024 + ] + x: cute.struct.Align[cutlass.Int32, 16] + compA: cute.struct.Align[complex, 16] + +**Static Queries:** + +.. code-block:: python + + # Get size and alignment at compile time + size = StorageB.__sizeof__() + align = StorageB.__alignof__() + +**Allocation and Access:** + +.. code-block:: python + + # Allocate and reference elements + storage = allocator.allocate(StorageB) + + # Access struct members + storage.a[0] = ... + storage.x = ... + ... = storage.compA.real.ptr + ... = storage.x.ptr.load() + +**Methods:** + +- ``__sizeof__()`` - Returns the size of the struct in bytes +- ``__alignof__()`` - Returns the alignment of the struct in bytes +- ``size_in_bytes()`` - Returns the size of the struct in bytes + + +struct.MemRange +^^^^^^^^^^^^^^^ + +``MemRange`` defines a contiguous range of memory with a specific element type and size. + +**Syntax:** + +.. code-block:: python + + cute.struct.MemRange[dtype, size] + +:param dtype: The data type (must be a DSL scalar type) +:type dtype: Type[Numeric] +:param size: The number of elements in the range +:type size: int + +**Properties:** + +- ``size`` - Number of elements in the range +- ``elem_width`` - Width of each element in bits +- ``size_in_bytes`` - Total size in bytes + +**Methods:** + +- ``data_ptr()`` - Returns a pointer to the start of the memory range +- ``get_tensor(layout, swizzle=None, dtype=None)`` - Creates a tensor from the memory range +- ``__getitem__(index)`` - Returns the element at the specified index + +**Examples:** + +.. code-block:: python + + @cute.struct + class Buffer: + data : cute.struct.MemRange[cutlass.Float32, 128] + + # Allocate buffer + buf = allocator.allocate(Buffer) + + # Get pointer to data + ptr = buf.data.data_ptr() + + # Access individual elements + element = buf.data[5] + + # Create tensor from memory range + layout = cute.make_layout((8, 16)) + tensor = buf.data.get_tensor(layout) + + +struct.Align +^^^^^^^^^^^^ + +``Align`` specifies explicit alignment requirements for struct members. + +**Syntax:** + +.. code-block:: python + + cute.struct.Align[dtype, alignment] + +:param dtype: The type to align (scalar, MemRange, or struct) +:type dtype: Type +:param alignment: The alignment in bytes (must be > 0) +:type alignment: int + +**Properties:** + +- ``dtype`` - The data type being aligned +- ``align`` - The alignment value + +**Examples:** + +.. code-block:: python + + @cute.struct + class AlignedStorage: + # Align scalar to 16 bytes + counter: cute.struct.Align[cutlass.Int32, 16] + + # Align array to 1024 bytes + buffer: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, 256], 1024 + ] + + +union +~~~~~ + +The ``union`` decorator abstracts C unions in Python DSL. Similar to ``struct``, but all members start at offset 0, and the size is the maximum size of all members. + +**Layout Characteristics:** + +- All objects start at offset 0 +- Alignment is the maximum alignment of all objects +- Size is the maximum size of all objects + +**Usage:** + +.. code-block:: python + + # Define a union with scalar elements + @cute.union + class value_union: + as_int : cutlass.Int32 + as_float : cutlass.Float32 + + # Allocate union + val = allocator.allocate(value_union) + + # Access different interpretations of same memory + val.as_int = 42 + float_val = val.as_float.ptr.load() # Interpret same bits as float + +**Methods:** + +Same as ``struct``: + +- ``__sizeof__()`` - Returns the size of the union in bytes +- ``__alignof__()`` - Returns the alignment of the union in bytes + + +Deprecated Types +---------------- + +ThrMma +~~~~~~ + +.. deprecated:: + ``cute.core.ThrMma`` is deprecated, use ``cute.ThrMma`` instead + +ThrCopy +~~~~~~~ + +.. deprecated:: + ``cute.core.ThrCopy`` is deprecated, use ``cute.ThrCopy`` instead + + +Type Hierarchies and Relationships +----------------------------------- + +**Type Protocol Support:** + +Many CuTe types implement standard Python protocols for integration: + +- ``__str__()`` - String representation for debugging +- ``__eq__()`` / ``__ne__()`` - Equality comparison +- ``__getitem__()`` - Indexing operations +- ``__add__()`` / ``__sub__()`` / ``__mul__()`` / ``__floordiv__()`` / ``__mod__()`` - Arithmetic + +**MLIR Integration:** + +Internal types like ``IntValue``, ``Layout``, ``Pointer``, and ``ComposedLayout`` are registered as MLIR value casters, enabling seamless integration with the underlying compiler infrastructure. + + +Best Practices +-------------- + +**Choosing Between Static and Dynamic:** + +- Use static values (Python ``int``) when dimensions are known at compile time for maximum optimization +- Use dynamic values (``IntValue``) when dimensions must be determined at runtime +- Refer to :doc:`dsl_dynamic_layout` for detailed guidance on static vs dynamic layouts + +**Memory Alignment:** + +- Always specify alignment requirements for shared memory structures to avoid bank conflicts +- Use ``struct.Align`` to enforce alignment constraints +- Check ``max_alignment`` properties to verify pointer and layout alignment + +**Layout Operations:** + +- Prefer built-in layout operations (``make_layout``, ``composition``, etc.) over manual construction +- Use ``ScaledBasis`` for explicit control over stride modes in multi-modal layouts +- Leverage ``ComposedLayout`` for complex transformations like swizzling + +**Type Safety:** + +- Use type annotations in ``@jit`` and ``@kernel`` functions +- Let the DSL infer types when possible for cleaner code +- Check ``dtype`` and ``memspace`` properties when working with pointers + + +See Also +-------- + +- :doc:`dsl_introduction` - Introduction to CuTe DSL decorators and calling conventions +- :doc:`dsl_control_flow` - Control flow with static and dynamic values +- :doc:`dsl_dynamic_layout` - Working with static and dynamic layouts +- :doc:`framework_integration` - Integration with deep learning frameworks +- :doc:`debugging` - Debugging techniques for CuTe DSL programs diff --git a/media/docs/pythonDSL/quick_start.rst b/media/docs/pythonDSL/quick_start.rst index e97e21a76..e6392c78b 100644 --- a/media/docs/pythonDSL/quick_start.rst +++ b/media/docs/pythonDSL/quick_start.rst @@ -3,7 +3,7 @@ Quick Start Guide ======================= -The CUTLASS DSL 4.4 release currently supports **Linux** and **Python 3.10 - 3.13** only. To install CUTLASS DSLs (limited to CuTe DSL for now), use the following command +The CUTLASS DSL 4.4 release currently supports **Linux** and **Python 3.10 - 3.14** only. To install CUTLASS DSLs (limited to CuTe DSL for now), use the following command Installation ----------------------- diff --git a/python/CuTeDSL/cutlass/__init__.py b/python/CuTeDSL/cutlass/__init__.py index 8836db845..3d00944bf 100644 --- a/python/CuTeDSL/cutlass/__init__.py +++ b/python/CuTeDSL/cutlass/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/__init__.py b/python/CuTeDSL/cutlass/base_dsl/__init__.py index 567d4da7a..32a251528 100644 --- a/python/CuTeDSL/cutlass/base_dsl/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py index 3989ba8fe..db064f198 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py index 010701400..b02c79d72 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py index 2df2874a0..b21e44e9d 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py index e9d5e2793..eed81e192 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py index ad7893ee4..4ef4c3c75 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -14,10 +14,48 @@ This module provides MLIR's OP helper functions """ import inspect +import os +import types from functools import wraps from ..._mlir import ir from ..common import DSLRuntimeError +from ..utils.stacktrace import walk_to_top_module + + +# The DSL package root is empty by default. +_DSL_PACKAGE_ROOT = "" + + +def _is_framework_frame(filename: str) -> bool: + """Check if a frame's filename belongs to DSL library code.""" + global _DSL_PACKAGE_ROOT + if _DSL_PACKAGE_ROOT == "": + # Compute the DSL package root once + # Any frame whose file starts with this prefix is considered DSL library code. + _DSL_PACKAGE_ROOT = walk_to_top_module( + os.path.dirname(os.path.abspath(__file__)) + ) + + if _DSL_PACKAGE_ROOT is None: + return False + + return os.path.abspath(filename).startswith(_DSL_PACKAGE_ROOT) + + +def _find_user_frame(start_frame: types.FrameType | None) -> types.FrameType | None: + """Walk up the call stack from start_frame to find the first user (non-library) frame. + + Returns the first frame whose file is not under the DSL package root. + Falls back to start_frame if no user frame is found (e.g. all frames are library code). + """ + frame = start_frame + while frame is not None: + if not _is_framework_frame(frame.f_code.co_filename): + return frame + frame = frame.f_back + # Fallback: if everything is framework code, use the original caller + return start_frame def dsl_user_op(opFunc): @@ -34,30 +72,39 @@ def dsl_user_op(opFunc): @wraps(opFunc) def wrapper(*args, **kwargs): loc = kwargs.pop("loc", None) - if loc is None: - frame = inspect.currentframe().f_back + frameInfo = None + verifier_error = False + + if loc is None and ir.Context.current is not None: + frame = _find_user_frame(inspect.currentframe().f_back) frameInfo = inspect.getframeinfo(frame) - # In Python < 3.11, getframeinfo returns a NamedTuple without positions - if not hasattr(frameInfo, "positions"): - file_loc = ir.Location.file( - frameInfo.filename, - frameInfo.lineno, - 0, + try: + # In Python < 3.11, getframeinfo returns a NamedTuple without positions + if not hasattr(frameInfo, "positions"): + file_loc = ir.Location.file( + frameInfo.filename, + frameInfo.lineno, + 0, + ) + else: + file_loc = ir.Location.file( + frameInfo.filename, + frameInfo.positions.lineno, + frameInfo.positions.col_offset or 0, + ) + loc = ir.Location.name( + ( + "".join([c.strip() for c in frameInfo.code_context]) + if frameInfo.code_context + else frameInfo.function + ), + childLoc=file_loc, ) - else: - file_loc = ir.Location.file( - frameInfo.filename, - frameInfo.positions.lineno, - frameInfo.positions.col_offset, - ) - loc = ir.Location.name( - ( - "".join([c.strip() for c in frameInfo.code_context]) - if frameInfo.code_context - else frameInfo.function - ), - childLoc=file_loc, - ) + except RuntimeError: + # No MLIR context available (e.g. validation-only call + # outside a kernel). Proceed with loc=None so that the + # wrapped function's own validation can still fire. + pass try: res_or_list = opFunc(*args, **kwargs, loc=loc) diff --git a/python/CuTeDSL/cutlass/base_dsl/arch.py b/python/CuTeDSL/cutlass/base_dsl/arch.py index 76b75d191..070a4d2fb 100644 --- a/python/CuTeDSL/cutlass/base_dsl/arch.py +++ b/python/CuTeDSL/cutlass/base_dsl/arch.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py b/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py index cda5137d7..56a46004c 100644 --- a/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py +++ b/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py index 5eea30185..b66744634 100644 --- a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py index ce920311a..167c94b32 100644 --- a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py +++ b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/common.py b/python/CuTeDSL/cutlass/base_dsl/common.py index be85fd994..eb0685456 100644 --- a/python/CuTeDSL/cutlass/base_dsl/common.py +++ b/python/CuTeDSL/cutlass/base_dsl/common.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/compiler.py b/python/CuTeDSL/cutlass/base_dsl/compiler.py index 10afaa4ca..41dbb957c 100644 --- a/python/CuTeDSL/cutlass/base_dsl/compiler.py +++ b/python/CuTeDSL/cutlass/base_dsl/compiler.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/dsl.py b/python/CuTeDSL/cutlass/base_dsl/dsl.py index 0305ba303..4b9c50a4a 100644 --- a/python/CuTeDSL/cutlass/base_dsl/dsl.py +++ b/python/CuTeDSL/cutlass/base_dsl/dsl.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -1341,7 +1341,11 @@ class BaseDSL(metaclass=DSLSingletonMeta): location=None, ): """Generate MLIR module and compile iself.T_provider.""" - with ir.Context(), self.get_ir_location(location): + with ir.Context() as ctx, self.get_ir_location(location): + # If threading is enabled, each MLIR context will keep alive a thread pool. + # When we cache MLIR compilation results, we also cache its context thus accumulating #(compilations) * thread_pool_size threads. + # Disable threading to avoid such excessive number of threads. + ctx.enable_multithreading(False) try: # Convert input arguments to MLIR arguments exe_args, func_types, adapted_args = self.generate_mlir_function_types( @@ -1491,6 +1495,11 @@ class BaseDSL(metaclass=DSLSingletonMeta): # Check if all non-default arguments are provided for param in sig.parameters.values(): + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue if ( param.default is inspect.Parameter.empty and param.name not in bound_args.arguments @@ -1501,6 +1510,95 @@ class BaseDSL(metaclass=DSLSingletonMeta): return sig + def _get_full_arg_spec(self, funcBody): + """ + Returns the full argument specification for a given function, handling PEP-563 + (postponed evaluation of type annotations) if necessary. + + If the function's annotations are provided as strings (which occurs when PEP-563 + is enabled), this method evaluates those annotations so they are returned as objects + instead of strings. + + Parameters + ---------- + funcBody : function + The function whose argument specification is to be retrieved. + + Returns + ------- + inspect.FullArgSpec + The complete argument specification of the function, with its annotations + properly evaluated and resolved where relevant. + """ + args_spec = inspect.getfullargspec(funcBody) + # Set `eval_str = True` to make it work when PEP-563 is enabled + if args_spec.annotations and all( + type(arg_type) is str for arg_type in args_spec.annotations.values() + ): + eval_annotations = inspect.get_annotations(funcBody, eval_str=True) + args_spec = inspect.FullArgSpec( + args_spec.args, + args_spec.varargs, + args_spec.varkw, + args_spec.defaults, + args_spec.kwonlyargs, + args_spec.kwonlydefaults, + eval_annotations, + ) + return args_spec + + @staticmethod + def _expand_varargs_varkw( + canonicalized_args: tuple, + canonicalized_kwargs: dict, + args_spec: inspect.FullArgSpec, + ) -> inspect.FullArgSpec: + """Expand *args and **kwargs into concrete named parameters in the FullArgSpec. + + When a JIT function uses *args or **kwargs, the concrete call-site values + are known. This method synthesizes named parameters for them so the rest + of the pipeline (which expects fixed-arity signatures) works unchanged. + + For *args: extra positional arguments beyond ``args_spec.args`` get + synthetic names ``_vararg_0``, ``_vararg_1``, etc. + + For **kwargs: extra keyword arguments beyond ``args_spec.kwonlyargs`` + are appended as keyword-only parameters. + """ + if not args_spec.varargs and not args_spec.varkw: + return args_spec + + expanded_args = list(args_spec.args) + expanded_annotations = dict(args_spec.annotations) + expanded_defaults = list(args_spec.defaults) if args_spec.defaults else [] + + if args_spec.varargs: + n_regular = len(args_spec.args) + n_extra = len(canonicalized_args) - n_regular + for i in range(n_extra): + expanded_args.append(f"varargs_{i}") + + expanded_kwonlyargs = list(args_spec.kwonlyargs) + expanded_kwonlydefaults = ( + dict(args_spec.kwonlydefaults) if args_spec.kwonlydefaults else {} + ) + + if args_spec.varkw: + existing_kwonly = set(args_spec.kwonlyargs) + for key in canonicalized_kwargs: + if key not in existing_kwonly: + expanded_kwonlyargs.append(key) + + return inspect.FullArgSpec( + args=expanded_args, + varargs=None, + varkw=None, + defaults=tuple(expanded_defaults) if expanded_defaults else None, + kwonlyargs=expanded_kwonlyargs, + kwonlydefaults=expanded_kwonlydefaults if expanded_kwonlydefaults else None, + annotations=expanded_annotations, + ) + def _func(self, funcBody, *args, **kwargs): """Decorator for MLIR functions. It cuts the boilerplate code, does the following: @@ -1553,6 +1651,10 @@ class BaseDSL(metaclass=DSLSingletonMeta): canonicalized_args, canonicalized_kwargs = self._canonicalize_args( sig, *args, **kwargs ) + # Expand *args/**kwargs into concrete named parameters + args_spec = self._expand_varargs_varkw( + canonicalized_args, canonicalized_kwargs, args_spec + ) # Simple name mangling function_name = self.mangle_name(function_name, canonicalized_args, args_spec) if func_name_prefix: diff --git a/python/CuTeDSL/cutlass/base_dsl/env_manager.py b/python/CuTeDSL/cutlass/base_dsl/env_manager.py index 19c663221..414227731 100644 --- a/python/CuTeDSL/cutlass/base_dsl/env_manager.py +++ b/python/CuTeDSL/cutlass/base_dsl/env_manager.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/export/__init__.py b/python/CuTeDSL/cutlass/base_dsl/export/__init__.py index 66eaf4031..d06596efe 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/export/c_header_generator.py b/python/CuTeDSL/cutlass/base_dsl/export/c_header_generator.py index f99a3e36c..6a07c2b88 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/c_header_generator.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/c_header_generator.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/export/export.py b/python/CuTeDSL/cutlass/base_dsl/export/export.py index 7e91ab005..9a4bd4b18 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/export.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/export.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/export/external_binary_module.py b/python/CuTeDSL/cutlass/base_dsl/export/external_binary_module.py index acd131ae4..7119d5e0e 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/external_binary_module.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/external_binary_module.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/jit_executor.py b/python/CuTeDSL/cutlass/base_dsl/jit_executor.py index 236049a6a..b702493f6 100644 --- a/python/CuTeDSL/cutlass/base_dsl/jit_executor.py +++ b/python/CuTeDSL/cutlass/base_dsl/jit_executor.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/__init__.py b/python/CuTeDSL/cutlass/base_dsl/runtime/__init__.py index 9af7e01f8..9706e0f29 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py index e7a60a3db..55a23d9f8 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/device_tensor.py b/python/CuTeDSL/cutlass/base_dsl/runtime/device_tensor.py index 7a3619761..ef89dbbee 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/device_tensor.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/device_tensor.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/dlpack_types.py b/python/CuTeDSL/cutlass/base_dsl/runtime/dlpack_types.py index 845bb01ca..b35d798c3 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/dlpack_types.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/dlpack_types.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/jit_arg_adapters.py b/python/CuTeDSL/cutlass/base_dsl/runtime/jit_arg_adapters.py index 0b27f959f..84952c549 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/jit_arg_adapters.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/jit_arg_adapters.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/stream_adapter.py b/python/CuTeDSL/cutlass/base_dsl/runtime/stream_adapter.py index e9df4c6a8..2d43fd3ce 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/stream_adapter.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/stream_adapter.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/tensor_descriptor.py b/python/CuTeDSL/cutlass/base_dsl/runtime/tensor_descriptor.py index 5b9469503..4d311dacc 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/tensor_descriptor.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/tensor_descriptor.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/__init__.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/__init__.py index 845ae0763..3fd297e9a 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/call_provider.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/call_provider.py index 0ce034754..6b47e8de7 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/call_provider.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/call_provider.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py index 976814289..56f56cd41 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py index 13c248359..1232c1780 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py index 5b13b2137..b8b0ebe49 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/typing.py b/python/CuTeDSL/cutlass/base_dsl/typing.py index 0ef742a5d..b299b6dde 100644 --- a/python/CuTeDSL/cutlass/base_dsl/typing.py +++ b/python/CuTeDSL/cutlass/base_dsl/typing.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -35,6 +35,7 @@ from .common import * from .ast_helpers import const_expr from ._mlir_helpers import arith as arith_helper, lru_cache_ir from ._mlir_helpers.arith import ArithValue +from ._mlir_helpers.op import dsl_user_op from .._mlir import ir from .._mlir.extras import types as T @@ -843,7 +844,6 @@ def _binary_op(op, promote_operand=True, promote_bool=False, flip=False): if flip: lhs_val, rhs_val = rhs_val, lhs_val - # Check if the operation is supported by the operands res_val = op(lhs_val, rhs_val) return res_type(res_val, loc=loc, ip=ip) @@ -1152,72 +1152,91 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): ) return res_type(value) + @dsl_user_op def __add__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.add, promote_bool=True)(self, other, loc=loc, ip=ip) + @dsl_user_op def __sub__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.sub, promote_bool=True)(self, other, loc=loc, ip=ip) + @dsl_user_op def __mul__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.mul, promote_bool=True)(self, other, loc=loc, ip=ip) + @dsl_user_op def __floordiv__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.floordiv, promote_bool=True)( self, other, loc=loc, ip=ip ) + @dsl_user_op def __truediv__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.truediv, promote_bool=True)( self, other, loc=loc, ip=ip ) + @dsl_user_op def __mod__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.mod, promote_bool=True)(self, other, loc=loc, ip=ip) + @dsl_user_op def __radd__(self, other, *, loc=None, ip=None) -> "Numeric": return self.__add__(other, loc=loc, ip=ip) + @dsl_user_op def __rsub__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.sub, promote_bool=True, flip=True)( self, other, loc=loc, ip=ip ) + @dsl_user_op def __rmul__(self, other, *, loc=None, ip=None) -> "Numeric": return self.__mul__(other, loc=loc, ip=ip) + @dsl_user_op def __rfloordiv__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.floordiv, promote_bool=True, flip=True)( self, other, loc=loc, ip=ip ) + @dsl_user_op def __rtruediv__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.truediv, promote_bool=True, flip=True)( self, other, loc=loc, ip=ip ) + @dsl_user_op def __rmod__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.mod, promote_bool=True, flip=True)( self, other, loc=loc, ip=ip ) + @dsl_user_op def __eq__(self, other, *, loc=None, ip=None) -> "Boolean": return _binary_op(operator.eq)(self, other, loc=loc, ip=ip) # type: ignore + @dsl_user_op def __ne__(self, other, *, loc=None, ip=None) -> "Boolean": return _binary_op(operator.ne)(self, other, loc=loc, ip=ip) # type: ignore + @dsl_user_op def __lt__(self, other, *, loc=None, ip=None) -> "Boolean": return _binary_op(operator.lt)(self, other, loc=loc, ip=ip) # type: ignore + @dsl_user_op def __le__(self, other, *, loc=None, ip=None) -> "Boolean": return _binary_op(operator.le)(self, other, loc=loc, ip=ip) # type: ignore + @dsl_user_op def __gt__(self, other, *, loc=None, ip=None) -> "Boolean": return _binary_op(operator.gt)(self, other, loc=loc, ip=ip) # type: ignore + @dsl_user_op def __ge__(self, other, *, loc=None, ip=None) -> "Boolean": return _binary_op(operator.ge)(self, other, loc=loc, ip=ip) # type: ignore + @dsl_user_op def __pow__(self, other, *, loc=None, ip=None) -> "Numeric": return _binary_op(operator.pow)(self, other, loc=loc, ip=ip) # type: ignore diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/__init__.py b/python/CuTeDSL/cutlass/base_dsl/utils/__init__.py index 479840202..b2458957d 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/logger.py b/python/CuTeDSL/cutlass/base_dsl/utils/logger.py index d47b3c8e4..2ce47c7cb 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/logger.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/logger.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py b/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py index 00d137cba..af899faa9 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/timer.py b/python/CuTeDSL/cutlass/base_dsl/utils/timer.py index 50d2c7f09..bf4a2ce0b 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/timer.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/timer.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py b/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py index 5939d2aa8..6e1e59f08 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/base_dsl/version_info.py b/python/CuTeDSL/cutlass/base_dsl/version_info.py index 203ddfb7a..122ee1933 100644 --- a/python/CuTeDSL/cutlass/base_dsl/version_info.py +++ b/python/CuTeDSL/cutlass/base_dsl/version_info.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/__init__.py b/python/CuTeDSL/cutlass/cute/__init__.py index 6f082f17f..94ac506f8 100644 --- a/python/CuTeDSL/cutlass/cute/__init__.py +++ b/python/CuTeDSL/cutlass/cute/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -105,6 +105,7 @@ from .core import ( E, # User defined struct struct, + union, pretty_str, make_layout_image_mask, repeat, diff --git a/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py b/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py index 8a5ebc9ee..e878d6e4a 100644 --- a/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py +++ b/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/algorithm.py b/python/CuTeDSL/cutlass/cute/algorithm.py index bc03231c3..f982c7b44 100644 --- a/python/CuTeDSL/cutlass/cute/algorithm.py +++ b/python/CuTeDSL/cutlass/cute/algorithm.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/arch/__init__.py b/python/CuTeDSL/cutlass/cute/arch/__init__.py index 0cda0a961..776398787 100644 --- a/python/CuTeDSL/cutlass/cute/arch/__init__.py +++ b/python/CuTeDSL/cutlass/cute/arch/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -50,6 +50,7 @@ __all__ = [ "block_in_cluster_idx", "block_in_cluster_dim", "block_idx_in_cluster", + "dynamic_smem_size", "shuffle_sync", "shuffle_sync_up", "shuffle_sync_down", diff --git a/python/CuTeDSL/cutlass/cute/arch/clc.py b/python/CuTeDSL/cutlass/cute/arch/clc.py index af344b5a1..e99c8d54d 100644 --- a/python/CuTeDSL/cutlass/cute/arch/clc.py +++ b/python/CuTeDSL/cutlass/cute/arch/clc.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/arch/elect.py b/python/CuTeDSL/cutlass/cute/arch/elect.py index abd213b18..d9a26db84 100644 --- a/python/CuTeDSL/cutlass/cute/arch/elect.py +++ b/python/CuTeDSL/cutlass/cute/arch/elect.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/arch/mbar.py b/python/CuTeDSL/cutlass/cute/arch/mbar.py index 17e541e8e..ca04dcc34 100644 --- a/python/CuTeDSL/cutlass/cute/arch/mbar.py +++ b/python/CuTeDSL/cutlass/cute/arch/mbar.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -35,7 +35,10 @@ def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None: :type cnt: Int """ nvvm.mbarrier_init_shared( - mbar_ptr.llvm_ptr, Int32(cnt).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), + Int32(cnt).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, ) @@ -65,7 +68,7 @@ def mbarrier_arrive_and_expect_tx( """ BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) - mbar_llvm_ptr = mbar_ptr.llvm_ptr + mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) if peer_cta_rank_in_cluster is not None: mbar_cluster_type = llvm.PointerType.get(AddressSpace.dsmem) mbar_llvm_ptr = nvvm.mapa( @@ -108,7 +111,7 @@ def mbarrier_expect_tx( """ BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) - mbar_llvm_ptr = mbar_ptr.llvm_ptr + mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) if peer_cta_rank_in_cluster is not None: mbar_cluster_type = llvm.PointerType.get(AddressSpace.dsmem) mbar_llvm_ptr = nvvm.mapa( @@ -150,7 +153,7 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX # The timeout in ns only applies to the latter and this call is truly blocking nvvm.mbarrier_try_wait_parity_shared( - mbar_ptr.llvm_ptr, + mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), Int32(phase).ir_value(loc=loc, ip=ip), Int32(timeout_ns).ir_value(loc=loc, ip=ip), loc=loc, @@ -174,7 +177,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo return Boolean( nvvm.mbarrier_wait_parity( - mbar_ptr.llvm_ptr, + mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), Int32(phase).ir_value(loc=loc, ip=ip), nvvm.MBarrierWaitKind.TRY, loc=loc, @@ -228,7 +231,7 @@ def mbarrier_arrive( the mbarrier is converted to a remote address in the peer CTA's SMEM. """ - mbar_llvm_ptr = mbar_ptr.llvm_ptr + mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) if peer_cta_rank_in_cluster is not None: BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) @@ -269,10 +272,5 @@ def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> N """ BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) - mbar_llvm_ptr = mbar_ptr.llvm_ptr - nvvm.cp_async_mbarrier_arrive_shared( - mbar_llvm_ptr, - noinc=True, - loc=loc, - ip=ip, - ) + mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) + nvvm.cp_async_mbarrier_arrive_shared(mbar_llvm_ptr, noinc=True, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py b/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py index ba9faa4e8..3d1a8cc83 100644 --- a/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py +++ b/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py index b45b283ca..7835a7ca9 100644 --- a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +++ b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -250,6 +250,26 @@ def block_idx_in_cluster(*, loc=None, ip=None) -> Int32: return Int32(nvvm.read_ptx_sreg_cluster_ctarank(T.i32(), loc=loc, ip=ip)) +@dsl_user_op +def dynamic_smem_size(*, loc=None, ip=None) -> Int32: + """ + Returns the launch dynamic smem size. + """ + return Int32( + llvm.inline_asm( + Int32.mlir_type, + [], + "mov.u32 $0, %dynamic_smem_size;\n", + "=r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + ) + + @dsl_user_op def shuffle_sync_op( value: Union[Numeric, "TensorSSA"], diff --git a/python/CuTeDSL/cutlass/cute/arch/smem.py b/python/CuTeDSL/cutlass/cute/arch/smem.py index 1a89db71e..ea5886a34 100644 --- a/python/CuTeDSL/cutlass/cute/arch/smem.py +++ b/python/CuTeDSL/cutlass/cute/arch/smem.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/arch/tmem.py b/python/CuTeDSL/cutlass/cute/arch/tmem.py index c5f637501..2b100aae9 100644 --- a/python/CuTeDSL/cutlass/cute/arch/tmem.py +++ b/python/CuTeDSL/cutlass/cute/arch/tmem.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/atom.py b/python/CuTeDSL/cutlass/cute/atom.py index 0d4ec36e9..0cda9aacf 100644 --- a/python/CuTeDSL/cutlass/cute/atom.py +++ b/python/CuTeDSL/cutlass/cute/atom.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/core.py b/python/CuTeDSL/cutlass/cute/core.py index 47339465f..e37152321 100644 --- a/python/CuTeDSL/cutlass/cute/core.py +++ b/python/CuTeDSL/cutlass/cute/core.py @@ -3,15 +3,17 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. from functools import partial, reduce +import inspect from inspect import isclass from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload +from types import MethodType from cutlass import const_expr from typing_extensions import deprecated @@ -31,6 +33,8 @@ from cutlass._mlir.dialects.cute import ( from cutlass.cutlass_dsl import ( T, const, + and_, + as_numeric, cutlass_arith, dsl_user_op, extract_mlir_values, @@ -704,7 +708,7 @@ class ScaledBasis: def __eq__(self, other): if isinstance(other, ScaledBasis): - return self.value == other.value and self.mode == other.mode + return and_(self.mode == other.mode, self.value == other.value) else: return False @@ -1212,6 +1216,10 @@ class _ComposedLayout(ComposedLayout): @property @dsl_user_op def shape(self, *, loc=None, ip=None) -> Shape: + return self.shape_method(loc=loc, ip=ip) + + @dsl_user_op + def shape_method(self, *, loc=None, ip=None) -> Shape: return _unpack_x_tuple( _cute_ir.get_shape(self.value, loc=loc, ip=ip), loc=loc, ip=ip ) @@ -1352,6 +1360,60 @@ class _Pointer(Pointer): def type(self) -> ir.Type: return self.value.type + @dsl_user_op + def load(self, *, loc=None, ip=None) -> Numeric: + # LLVM doesn't support load/store narrow precision per element + tmp_ty = self.dtype.mlir_type + if self.dtype is Boolean or self.dtype.width == 8: + tmp_ty = T.i8() + elif self.dtype.width < 8: + raise ValueError( + f"Loading narrow precision type {self.dtype} is not supported" + ) + + llvm_ptr = self.to_llvm_ptr(loc=loc, ip=ip) + tmp_val = llvm.load(tmp_ty, llvm_ptr, loc=loc, ip=ip) + if self.dtype.width == 8: + tmp_val = arith.bitcast(self.dtype.mlir_type, tmp_val, loc=loc, ip=ip) + + return self.dtype(tmp_val, loc=loc, ip=ip) + + @dsl_user_op + def store( + self, + value: Union[Numeric, cutlass_arith.ArithValue, int, float, bool], + *, + loc=None, + ip=None, + ): + if isinstance(value, (int, float, bool, cutlass_arith.ArithValue)): + value = self.dtype(value, loc=loc, ip=ip) + elif isinstance(value, Numeric): + if value.dtype is not self.dtype: + value = value.to(self.dtype, loc=loc, ip=ip) + else: + raise ValueError(f"Unsupported value type: {type(value)}") + # LLVM doesn't support load/store narrow precision per element + tmp_val = value.ir_value(loc=loc, ip=ip) + if self.dtype.width == 8: + tmp_val = arith.bitcast(T.i8(), tmp_val, loc=loc, ip=ip) + elif self.dtype is not Boolean and self.dtype.width < 8: + raise ValueError( + f"Storing narrow precision type {self.dtype} is not supported" + ) + + llvm_ptr = self.to_llvm_ptr(loc=loc, ip=ip) + return llvm.store(tmp_val, llvm_ptr, loc=loc, ip=ip) + + @dsl_user_op + def __getitem__(self, idx: Int, *, loc=None, ip=None) -> Pointer: + return (self + idx).load() + + @dsl_user_op + def __setitem__(self, idx: Int, value: Numeric, *, loc=None, ip=None) -> Pointer: + (self + idx).store(value, loc=loc, ip=ip) + return value + # Only use if you absolutely need to get the LLVM pointer Value @property @dsl_user_op @@ -1360,6 +1422,25 @@ class _Pointer(Pointer): """ Get the LLVM pointer representation of this pointer. + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR, defaults to None + :type ip: Optional[InsertionPoint] + :return: The LLVM pointer representation + :rtype: ir.Value + """ + return self.to_llvm_ptr(loc=loc, ip=ip) + + @dsl_user_op + @lru_cache_ir() + def to_llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: + """ + Get the LLVM pointer representation of this pointer. (Used by internal API to propagate loc and ip) + + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR, defaults to None + :type ip: Optional[InsertionPoint] :return: The LLVM pointer representation :rtype: ir.Value """ @@ -1587,7 +1668,7 @@ def pretty_str(arg) -> str: @dsl_user_op -def printf(*args, loc=None, ip=None) -> None: +def printf(*args, loc=None, ip=None, end="\n") -> None: """ Print one or more values with optional formatting. @@ -1607,6 +1688,8 @@ def printf(*args, loc=None, ip=None) -> None: :type loc: Optional[Location] :param ip: Insertion point for code generation, defaults to None :type ip: Optional[InsertionPoint] + :param end: Suffix for the printed value, defaults to newline + :type end: Optional[str] :raises ValueError: If no arguments are provided :raises TypeError: If an unsupported argument type is passed @@ -1636,10 +1719,10 @@ def printf(*args, loc=None, ip=None) -> None: raise ValueError("expects at least one argument to print") if isinstance(args[0], str): - fmt = args[0] + "\n" + fmt = args[0] + end args = args[1:] else: - fmt = "{}" + ", {}" * (len(args) - 1) + "\n" + fmt = "{}" + ", {}" * (len(args) - 1) + end def process_arg(arg): arg0 = arg.value if isinstance(arg, Numeric) else arg @@ -3384,7 +3467,7 @@ def recast_ptr( if cvt_type is None: if not isclass(dtype) or not issubclass(dtype, Numeric): raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") - cvt_type = dtype.mlir_type + cvt_type = T.i8() if dtype is Boolean else dtype.mlir_type dtype = cvt_type value_type = ptr.type.value_type if dtype is None else dtype @@ -4287,8 +4370,8 @@ class struct: storage = allocator.allocate(StorageB) storage.a[0] ... - storage.x ... - storage.compA.real ... + storage.x.ptr ... + storage.compA.real.ptr ... :param cls: The struct class with annotations. :return: The decorated struct class. @@ -4306,8 +4389,8 @@ class struct: :ivar _size: The size of the MemRange. """ - _dtype = None - _size = None + _dtype: Optional[Numeric] = None + _size: Optional[int] = None def __new__(cls, name, bases, dct): new_cls = super().__new__(cls, name, bases, dct) @@ -4337,7 +4420,7 @@ class struct: @property def elem_width(cls): - return cls._dtype.width + return cls._dtype.width if cls._dtype is not Boolean else 8 @property def size_in_bytes(cls): @@ -4368,12 +4451,15 @@ class struct: case the range can only be used for its address (e.g. as a partition marker). :param base: The base address of the memory range. """ - self._dtype = dtype - self._size = size - self._base = base + self._dtype: Optional[Numeric] = dtype + self._size: Optional[int] = size + self._base: Optional[Pointer] = base + + def __repr__(self): + return f"{object.__repr__(self)} " @dsl_user_op - def data_ptr(self, *, loc=None, ip=None): + def data_ptr(self, *, loc=None, ip=None) -> Pointer: """ Returns start pointer to the data in this memory range. @@ -4384,7 +4470,9 @@ class struct: return recast_ptr(self._base, dtype=self._dtype, loc=loc, ip=ip) @dsl_user_op - def get_tensor(self, layout, swizzle=None, dtype=None, *, loc=None, ip=None): + def get_tensor( + self, layout, swizzle=None, dtype=None, *, loc=None, ip=None + ) -> Tensor: """ Creates a tensor from the memory range. @@ -4404,9 +4492,10 @@ class struct: elem_type = self._dtype if dtype is None else dtype ptr = recast_ptr(self._base, swizzle, dtype=elem_type, loc=loc, ip=ip) res = make_tensor(ptr, layout, loc=loc, ip=ip) - return res + return type(res)(res, dtype=elem_type, loc=loc, ip=ip) - def __getitem__(self, index: int) -> Any: + @dsl_user_op + def __getitem__(self, index: int, *, loc=None, ip=None) -> Any: """ Returns the element at the specified index in the memory range. @@ -4415,7 +4504,21 @@ class struct: :raises AssertionError: If the index is out of range. """ assert (index >= 0) and (index < self._size) - return self.data_ptr() + index + ptr = self.data_ptr() + index + return ptr.load(loc=loc, ip=ip) + + @dsl_user_op + def __setitem__(self, index: int, val, *, loc=None, ip=None): + """ + Set element value at the specified index in the memory range. + + :param index: The index of the element to retrieve. + :val: The element value at the specified index. + :raises AssertionError: If the index is out of range. + """ + assert (index >= 0) and (index < self._size) + ptr = self.data_ptr() + index + ptr.store(as_numeric(val).to(self._dtype), loc=loc, ip=ip) # inner class for aligning a member type class _AlignMeta(type): @@ -4430,8 +4533,8 @@ class struct: :ivar _align: The alignment of the data type. """ - _dtype = None - _align = None + _dtype: Optional[Any] = None + _align: Optional[int] = None def __new__(cls, name, bases, dct): return super().__new__(cls, name, bases, dct) @@ -4473,6 +4576,88 @@ class struct: pass + class _ScalarData(_Pointer): + """ + Represents a scalar value at a given pointer location in memory. + + This class provides utility methods to get a scalar pointer. + It wraps a pointer to a scalar element and enables element-wise memory operations. + + :ivar _ptr: The underlying pointer to the scalar value. + """ + + def __init__(self, ptr): + self._ptr: Optional[_Pointer] = ptr + + def __repr__(self): + return f"{object.__repr__(self)} <{self.dtype}> " + + def __get_mlir_types__(self) -> List[ir.Type]: + return [self.value.type] + + def __extract_mlir_values__(self) -> List[ir.Value]: + return [self.value] + + def __new_from_mlir_values__(self, values) -> Pointer: + ptr = _Pointer( + values[0] if isinstance(values[0], ir.Value) else values[0].value + ) + return self.__class__(ptr) + + @dsl_user_op + def to_llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: + llvm_ptr_ty = llvm.PointerType.get( + self._ptr.memspace.value + if self._ptr.memspace != AddressSpace.rmem + else 0 + ) + return builtin.unrealized_conversion_cast( + [llvm_ptr_ty], [self.value], loc=loc, ip=ip + ) + + @property + def ptr(self) -> Pointer: + """ + Get the underlying pointer. + + :return: The pointer to the scalar value. + :rtype: Pointer + """ + return self._ptr + + @property + def dtype(self) -> Numeric: + """ + Get the data type of the scalar value. + + :return: The numeric data type of the underlying pointer. + :rtype: Numeric + """ + return self._ptr.dtype + + @property + @deprecated("Using `struct.scalar` as pointer is deprecated.") + def value(self): + """ + Get the raw MLIR value of the underlying pointer. + + .. deprecated:: + Using ``struct.scalar`` as pointer is deprecated. + Use explicit ``struct.scalar.ptr`` for pointer instead. + + :return: The MLIR value of the underlying pointer. + :rtype: ir.Value + """ + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "Use explicit `struct.scalar.ptr` for pointer instead.", + DeprecationWarning, + ) + return self._ptr.value + # util func for base dsl scalar types @staticmethod def _is_scalar_type(dtype): @@ -4493,55 +4678,72 @@ class struct: :raises TypeError: If the struct is empty. """ self._cls = cls - self.__name__ = f"struct::{cls.__name__}" + self.__name__ = f"cute.struct::{cls.__name__}" # Get the class annotations self._annotations = getattr(cls, "__annotations__", {}) # Create a dictionary to store the offsets self._offsets: Dict[str, int] = {} + # Override `setattr` function for struct to assign scalar properly + def struct_setattr(self, name, value): + attr = getattr(self, name, None) + if isinstance(attr, struct._ScalarData): + value = as_numeric(value).to(attr.dtype) + attr.ptr.store(value) + else: + raise ValueError(f"cannot assign value to `{name}` in {self.__name__}") + + type.__setattr__(self._cls, "__setattr__", struct_setattr) + + # Override `__repr__` function for struct info + def struct_repr(self): + return f"{object.__repr__(self)} <{self.__name__}> " + + self._cls.__repr__ = struct_repr + # Calculate the offsets and alignment offset = 0 alignment = 1 if len(self._annotations) == 0: raise TypeError("Empty struct is not supported!") - for name, object in self._annotations.items(): - # get alignment of object + for name, member in self._annotations.items(): + # get alignment of member sub_align = 1 - if isinstance(object, struct._AlignMeta): - sub_align = object.align - object = object.dtype + if isinstance(member, struct._AlignMeta): + sub_align = member.align + member = member.dtype # switch addition order to support dynamic size def add_offset(val): return val + offset if isinstance(val, ir.Value) else offset + val # size of scalar - if struct._is_scalar_type(object): - dtype_size = max(1, object.width // 8) + if struct._is_scalar_type(member): + dtype_size = max(1, member.width // 8) sub_align = max(dtype_size, sub_align) offset = self.align_offset(offset, sub_align) self._offsets[name] = offset offset = add_offset(dtype_size) # size of array is size_in_bytes, alignment is elem_size - elif isinstance(object, struct._MemRangeMeta): + elif isinstance(member, struct._MemRangeMeta): # Allow empty array as a free marker-only struct member. # Use max(sub_align, ) because we might have in the future some - # object.elem_width less than 8, such as fp4, bit and others, + # member.elem_width less than 8, such as fp4, bit and others, # and align_offset() does not support an alignment of 0. - sub_align = max(object.elem_width // 8, sub_align) + sub_align = max(member.elem_width // 8, sub_align) offset = self.align_offset(offset, sub_align) self._offsets[name] = offset - offset = add_offset(object.size_in_bytes) + offset = add_offset(member.size_in_bytes) # size of struct - elif isinstance(object, struct): - sub_align = max(object.__alignof__(), sub_align) + elif isinstance(member, struct): + sub_align = max(member.__alignof__(), sub_align) offset = self.align_offset(offset, sub_align) self._offsets[name] = offset - offset = add_offset(object.__sizeof__()) + offset = add_offset(member.__sizeof__()) else: raise TypeError( f"Struct element only support struct/array/base_dsl scalar, " - f"but got {object}" + f"but got {member}" ) # Total alignment determined by the strictest requirement alignment = max(alignment, sub_align) @@ -4564,20 +4766,22 @@ class struct: # make an new object of user-defined decorated struct # otherwise it will override same self._cls when new instance created cls = self._cls() - setattr(cls, "_base", base) + object.__setattr__(cls, "base", base) + object.__setattr__(cls, "__name__", self.__name__) for name, off in self._offsets.items(): obj = self._annotations[name] if isinstance(obj, struct._AlignMeta): obj = obj.dtype if struct._is_scalar_type(obj): - new_obj = recast_ptr(base + off, dtype=obj, loc=loc, ip=ip) - setattr(cls, name, new_obj) + ptr = recast_ptr(base + off, dtype=obj, loc=loc, ip=ip) + new_obj = struct._ScalarData(ptr) + object.__setattr__(cls, name, new_obj) elif isinstance(obj, struct._MemRangeMeta): new_obj = struct._MemRangeData(obj._dtype, obj._size, base + off) - setattr(cls, name, new_obj) + object.__setattr__(cls, name, new_obj) elif isinstance(obj, struct): new_obj = obj(base + off) - setattr(cls, name, new_obj) + object.__setattr__(cls, name, new_obj) else: raise TypeError( f"Struct element only support struct/array/base_dsl scalar, " @@ -4614,6 +4818,196 @@ class struct: return (offset + (align - 1)) & ~(align - 1) +############################################################################## +# User defined struct +############################################################################## + + +class union(struct): + """ + Decorator to abstract C union in Python DSL. + + Similar to cute.struct, but lays out objects as a union: + - All objects start at offset 0 + - The alignment is the maximum alignment of all objects + - The size is the maximum size of all objects + + **Usage:**Expand commentComment on line R4131 + + .. code-block:: python + + # Define a union with scalar int/float elements: + @cute.union + class value_union: + as_int : cutlass.Int32 + as_float : cutlass.Float32 + + + @cute.union + class data_union: + small : cutlass.Int16 + medium : cutlass.Int32 + large : cutlass.Int64 + + + # Supports alignment for its elements: + @cute.union + class aligned_union: + a: cute.struct.Align[cutlass.Float32, 16] + b: cute.struct.Align[cutlass.Int32, 8] + + + # Statically get size and alignment: + size = data_union.__sizeof__() + align = data_union.__alignof__() + + # Allocate and reference elements: + allocator = cutlass.utils.SmemAllocator() + value = allocator.allocate(data_union) + + # Access union members (all at the same offset): + value.small.ptr ... + value.medium.ptr ... + value.large.ptr ... + + :param cls: The union class with annotations. + :return: The decorated union class. + """ + + def __init__(self, cls): + """ + Initializes a new cute.union decorator instance. + + :param cls: The class representing the union data type. + :raises TypeError: If the union is empty. + """ + object.__setattr__(self, "_cls", cls) + object.__setattr__(self, "__name__", f"cute.union::{cls.__name__}") + # Get the class annotations + object.__setattr__(self, "_annotations", getattr(cls, "__annotations__", {})) + # Create a dictionary to store the offsets (all zeros for union) + object.__setattr__(self, "_offsets", {}) + + # Override `setattr` function for struct to assign scalar properly + def union_setattr(self, name, value): + attr = getattr(self, name, None) + if isinstance(attr, struct._ScalarData): + value = as_numeric(value).to(attr.dtype) + attr.ptr.store(value) + else: + raise ValueError(f"cannot assign value to `{name}` in {self.__name__}") + + type.__setattr__(self._cls, "__setattr__", union_setattr) + + # Override `__repr__` function for struct info + def union_repr(self): + return f"{object.__repr__(self)} <{self.__name__}> " + + type.__setattr__(self._cls, "__repr__", union_repr) + + # Calculate the maximum size and alignment + max_size = 0 + max_alignment = 1 + if len(self._annotations) == 0: + raise TypeError("Empty union is not supported!") + for name, item in self._annotations.items(): + # All offsets are 0 for a union + self._offsets[name] = 0 + + # Get alignment of object + sub_align = 1 + if isinstance(item, struct._AlignMeta): + sub_align = item.align + item = item.dtype + + # Calculate size and alignment based on object type + if struct._is_scalar_type(item): + dtype_size = max(1, item.width // 8) + sub_align = max(dtype_size, sub_align) + max_size = max(max_size, dtype_size) + elif isinstance(item, struct._MemRangeMeta): + sub_align = max(item.elem_width // 8, sub_align) + max_size = max(max_size, item.size_in_bytes) + elif isinstance(item, struct): + sub_align = max(item.__alignof__(), sub_align) + max_size = max(max_size, item.__sizeof__()) + else: + raise TypeError( + f"Union element only support struct/array/DSL scalar, " + f"but got `{item.__qualname__}`" + ) + # Union alignment is the maximum alignment of all members + max_alignment = max(max_alignment, sub_align) + + # Union size is the maximum size, aligned to the maximum alignment + object.__setattr__(self, "_align_of", max_alignment) + object.__setattr__( + self, "_size_of", struct.align_offset(max_size, max_alignment) + ) + + @dsl_user_op + def __call__(self, base: Any, *, loc=None, ip=None) -> None: + """ + Creates a new instance of the decorated union. + + :param base: The base address of the union. + :return: An instance of the decorated union. + :raises TypeError: If the base pointer is not byte-sized. + """ + if base.type.value_type.width != 8: + raise TypeError("union base ptr value type must be byte sized.") + # Make a new object of user-defined decorated union + cls = self._cls() + object.__setattr__(cls, "base", base) + object.__setattr__(cls, "__name__", self.__name__) + for name, off in self._offsets.items(): + obj = self._annotations[name] + if isinstance(obj, struct._AlignMeta): + obj = obj.dtype + if struct._is_scalar_type(obj): + ptr = recast_ptr(base + off, dtype=obj, loc=loc, ip=ip) + new_obj = struct._ScalarData(ptr) + object.__setattr__(cls, name, new_obj) + elif isinstance(obj, struct._MemRangeMeta): + new_obj = struct._MemRangeData(obj._dtype, obj._size, base + off) + object.__setattr__(cls, name, new_obj) + elif isinstance(obj, struct): + new_obj = obj(base + off) + object.__setattr__(cls, name, new_obj) + else: + raise TypeError( + f"Union element only support struct/array/DSL scalar, " + f"but got `{obj.__qualname__}`" + ) + return cls + + def __setattr__(self, name, value): + raise TypeError("Cannot add a new field after initialization") + def size_in_bytes(self) -> int: + """ + Returns the size of the union in bytes. + + :return: The size of the union. + """ + return self._size_of + + def __sizeof__(self) -> int: + """ + Returns the size of the union in bytes. + + :return: The size of the union. + """ + return self._size_of + + def __alignof__(self) -> int: + """ + Returns the alignment of the union in bytes. + + :return: The alignment of the union. + """ + return self._align_of + + # Deprecated usage but keep them to avoid breaking some examples uses `cute.core.ThrMma` from .atom import ThrCopy as _ThrCopy @@ -4752,6 +5146,23 @@ class FastDivmodDivisor: return f"FastDivmodDivisor({self._divisor.type})" +# Set explicit signature for Sphinx documentation to avoid issues with @dsl_user_op decorator +FastDivmodDivisor.__init__.__signature__ = inspect.Signature( + [ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter( + "divisor", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Integer + ), + inspect.Parameter( + "is_power_of_2", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=None, + annotation=bool, + ), + ] +) + + @dsl_user_op def fast_divmod_create_divisor( divisor: Integer, *, loc=None, ip=None diff --git a/python/CuTeDSL/cutlass/cute/experimental/__init__.py b/python/CuTeDSL/cutlass/cute/experimental/__init__.py index bee9202d7..571629656 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/__init__.py +++ b/python/CuTeDSL/cutlass/cute/experimental/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/experimental/algorithm.py b/python/CuTeDSL/cutlass/cute/experimental/algorithm.py index bf8d7cf48..90766ca98 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/algorithm.py +++ b/python/CuTeDSL/cutlass/cute/experimental/algorithm.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/experimental/math.py b/python/CuTeDSL/cutlass/cute/experimental/math.py index 92a98471d..5eedee9c5 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/math.py +++ b/python/CuTeDSL/cutlass/cute/experimental/math.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/experimental/memory.py b/python/CuTeDSL/cutlass/cute/experimental/memory.py index 5294e09ad..9d36437d1 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/memory.py +++ b/python/CuTeDSL/cutlass/cute/experimental/memory.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/experimental/utils.py b/python/CuTeDSL/cutlass/cute/experimental/utils.py index 5a9c72bef..cfcce7288 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/utils.py +++ b/python/CuTeDSL/cutlass/cute/experimental/utils.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/export/__init__.py b/python/CuTeDSL/cutlass/cute/export/__init__.py index 1b9546ba2..8b2fe7aea 100644 --- a/python/CuTeDSL/cutlass/cute/export/__init__.py +++ b/python/CuTeDSL/cutlass/cute/export/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/export/aot_config.py b/python/CuTeDSL/cutlass/cute/export/aot_config.py index 5922b91cf..3c6c62025 100644 --- a/python/CuTeDSL/cutlass/cute/export/aot_config.py +++ b/python/CuTeDSL/cutlass/cute/export/aot_config.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/export/c_header_generator.py b/python/CuTeDSL/cutlass/cute/export/c_header_generator.py index 7d5ff0e5b..762ce1222 100644 --- a/python/CuTeDSL/cutlass/cute/export/c_header_generator.py +++ b/python/CuTeDSL/cutlass/cute/export/c_header_generator.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -195,7 +195,7 @@ typedef struct {{ packed_args.append("&" + arg_name) else: raise DSLRuntimeError( - f"Unsupported argument for c function argument generation: {arg} with type {arg_type}" + f"Unsupported argument for c function argument generation: {arg_name} = {arg} with type annotation {arg_type}" ) return arguments, packed_args, declarations diff --git a/python/CuTeDSL/cutlass/cute/export/export.py b/python/CuTeDSL/cutlass/cute/export/export.py index fa27ce8b4..9375c954e 100644 --- a/python/CuTeDSL/cutlass/cute/export/export.py +++ b/python/CuTeDSL/cutlass/cute/export/export.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/export/load.py b/python/CuTeDSL/cutlass/cute/export/load.py index 00e1a9b25..4eb85a351 100644 --- a/python/CuTeDSL/cutlass/cute/export/load.py +++ b/python/CuTeDSL/cutlass/cute/export/load.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/ffi.py b/python/CuTeDSL/cutlass/cute/ffi.py index 66e18a47b..727d1a296 100644 --- a/python/CuTeDSL/cutlass/cute/ffi.py +++ b/python/CuTeDSL/cutlass/cute/ffi.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/math.py b/python/CuTeDSL/cutlass/cute/math.py index 16a511a6e..8dab6b981 100644 --- a/python/CuTeDSL/cutlass/cute/math.py +++ b/python/CuTeDSL/cutlass/cute/math.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -15,6 +15,7 @@ from .typing import Numeric from .tensor import TensorSSA from cutlass._mlir.dialects import math, arith +from cutlass.cutlass_dsl import dsl_user_op def _math_op(func: Callable, fastmath: bool, *args, **kwargs): @@ -22,7 +23,7 @@ def _math_op(func: Callable, fastmath: bool, *args, **kwargs): :param func: The function to dispatch :param args: The input tensor or scalar - :param kwargs: The input tensor or scalar + :param kwargs: Extra keyword arguments (loc, ip) forwarded to the MLIR op """ arg_type = type(args[0]) for arg in args: @@ -40,15 +41,16 @@ def _math_op(func: Callable, fastmath: bool, *args, **kwargs): fastmath_flag = arith.FastMathFlags.fast if fastmath else arith.FastMathFlags.none if isinstance(args[0], TensorSSA): return TensorSSA( - func(*args, fastmath=fastmath_flag), args[0].shape, args[0].dtype + func(*args, fastmath=fastmath_flag, **kwargs), args[0].shape, args[0].dtype ) else: args = [a.ir_value() for a in args] - return func(*args, fastmath=fastmath_flag) + return func(*args, fastmath=fastmath_flag, **kwargs) +@dsl_user_op def acos( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise arc cosine of the input tensor. @@ -56,6 +58,10 @@ def acos( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the arc cosine of each element in input tensor :rtype: Union[TensorSSA, Numeric] @@ -67,11 +73,12 @@ def acos( y = x.load() # Load values z = acos(y) # Compute arc cosine """ - return _math_op(math.acos, fastmath, a) + return _math_op(math.acos, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def asin( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise arc sine of the input tensor. @@ -79,6 +86,10 @@ def asin( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the arc sine of each element in input tensor :rtype: Union[TensorSSA, Numeric] @@ -90,11 +101,12 @@ def asin( y = x.load() # Load values z = asin(y) # Compute arc sine """ - return _math_op(math.asin, fastmath, a) + return _math_op(math.asin, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def atan( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise arc tangent of the input tensor. @@ -102,6 +114,10 @@ def atan( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the arc tangent of each element in input tensor :rtype: Union[TensorSSA, Numeric] @@ -113,11 +129,13 @@ def atan( y = x.load() # Load values z = atan(y) # Compute arc tangent """ - return _math_op(math.atan, fastmath, a) + return _math_op(math.atan, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def atan2( - a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False, + *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise arc tangent of two tensors. @@ -130,6 +148,10 @@ def atan2( :type b: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the arc tangent of a/b element-wise :rtype: Union[TensorSSA, Numeric] @@ -141,11 +163,12 @@ def atan2( x = cute.make_rmem_tensor(ptr2, layout).load() # x coordinates theta = atan2(y, x) # Compute angles """ - return _math_op(math.atan2, fastmath, a, b) + return _math_op(math.atan2, fastmath, a, b, loc=loc, ip=ip) +@dsl_user_op def cos( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise cosine of the input tensor. @@ -153,6 +176,10 @@ def cos( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the cosine of each element :rtype: Union[TensorSSA, Numeric] @@ -164,11 +191,12 @@ def cos( y = x.load() # Load values z = cos(y) # Compute cosine """ - return _math_op(math.cos, fastmath, a) + return _math_op(math.cos, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def erf( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise error function of the input tensor. @@ -179,6 +207,10 @@ def erf( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the error function value for each element :rtype: Union[TensorSSA, Numeric] @@ -190,11 +222,12 @@ def erf( y = x.load() # Load values z = erf(y) # Compute error function """ - return _math_op(math.erf, fastmath, a) + return _math_op(math.erf, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def exp( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise exponential of the input tensor. @@ -202,6 +235,10 @@ def exp( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the exponential of each element :rtype: Union[TensorSSA, Numeric] @@ -213,11 +250,12 @@ def exp( y = x.load() # Load values z = exp(y) # Compute exponential """ - return _math_op(math.exp, fastmath, a) + return _math_op(math.exp, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def exp2( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise base-2 exponential of the input tensor. @@ -225,6 +263,10 @@ def exp2( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing 2 raised to the power of each element :rtype: Union[TensorSSA, Numeric] @@ -236,11 +278,12 @@ def exp2( y = x.load() # Load values z = exp2(y) # Compute 2^x """ - return _math_op(math.exp2, fastmath, a) + return _math_op(math.exp2, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def log( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise natural logarithm of the input tensor. @@ -248,6 +291,10 @@ def log( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the natural logarithm of each element :rtype: Union[TensorSSA, Numeric] @@ -259,11 +306,12 @@ def log( y = x.load() # Load values z = log(y) # Compute natural logarithm """ - return _math_op(math.log, fastmath, a) + return _math_op(math.log, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def log2( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise base-2 logarithm of the input tensor. @@ -271,6 +319,10 @@ def log2( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the base-2 logarithm of each element :rtype: Union[TensorSSA, Numeric] @@ -282,11 +334,12 @@ def log2( y = x.load() # Load values z = log2(y) # Compute log base 2 """ - return _math_op(math.log2, fastmath, a) + return _math_op(math.log2, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def log10( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise base-10 logarithm of the input tensor. @@ -294,6 +347,10 @@ def log10( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the base-10 logarithm of each element :rtype: Union[TensorSSA, Numeric] @@ -305,11 +362,12 @@ def log10( y = x.load() # Load values z = log10(y) # Compute log base 10 """ - return _math_op(math.log10, fastmath, a) + return _math_op(math.log10, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def rsqrt( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise reciprocal square root of the input tensor. @@ -319,6 +377,10 @@ def rsqrt( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the reciprocal square root of each element :rtype: Union[TensorSSA, Numeric] @@ -330,11 +392,12 @@ def rsqrt( y = x.load() # Load values z = rsqrt(y) # Compute 1/√x """ - return _math_op(math.rsqrt, fastmath, a) + return _math_op(math.rsqrt, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def sin( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise sine of the input tensor. @@ -342,6 +405,10 @@ def sin( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the sine of each element :rtype: Union[TensorSSA, Numeric] @@ -353,11 +420,12 @@ def sin( y = x.load() # Load values z = sin(y) # Compute sine """ - return _math_op(math.sin, fastmath, a) + return _math_op(math.sin, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def sqrt( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise square root of the input tensor. @@ -365,6 +433,10 @@ def sqrt( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the square root of each element :rtype: Union[TensorSSA, Numeric] @@ -376,11 +448,12 @@ def sqrt( y = x.load() # Load values z = sqrt(y) # Compute square root """ - return _math_op(math.sqrt, fastmath, a) + return _math_op(math.sqrt, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def tan( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise tangent of the input tensor. @@ -388,6 +461,10 @@ def tan( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the tangent of each element :rtype: Union[TensorSSA, Numeric] @@ -399,11 +476,12 @@ def tan( y = x.load() # Load values z = tan(y) # Compute tangent """ - return _math_op(math.tan, fastmath, a) + return _math_op(math.tan, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def tanh( - a: Union[TensorSSA, Numeric], fastmath: bool = False + a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None ) -> Union[TensorSSA, Numeric]: """Compute element-wise hyperbolic tangent of the input tensor. @@ -411,6 +489,10 @@ def tanh( :type a: Union[TensorSSA, Numeric] :param fastmath: Enable fast math optimizations, defaults to False :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] :return: Tensor containing the hyperbolic tangent of each element :rtype: Union[TensorSSA, Numeric] @@ -422,7 +504,7 @@ def tanh( y = x.load() # Load values z = tanh(y) # Compute hyperbolic tangent """ - return _math_op(math.tanh, fastmath, a) + return _math_op(math.tanh, fastmath, a, loc=loc, ip=ip) __all__ = [ diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py index c3e704421..3a72a7eae 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/common.py b/python/CuTeDSL/cutlass/cute/nvgpu/common.py index 053667e96..5cbe508d5 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/common.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/common.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py index 5dd3f3bc6..8337397a1 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py index b9c167b2a..e5250886b 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py index ec6789774..e02cd1469 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py index 192490571..106defcfd 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py index 74d1ce7fc..9d7d4ec11 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py index 0f165af6e..562affb1e 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py index 760f05d34..d9d545b19 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -105,27 +105,17 @@ def make_smem_layout_atom( SmemLayoutAtomKind.MN_SW128_32B, ): # M/N-major layout - return core.make_composed_layout( - sw, - 0, - core.make_layout( - (num_contiguous_elems, 8), stride=(1, num_contiguous_elems) - ), - loc=loc, - ip=ip, + outer = core.make_layout( + (num_contiguous_elems, 8), stride=(1, num_contiguous_elems), loc=loc, ip=ip ) else: # K-major layout - return core.make_composed_layout( - sw, - 0, - core.make_layout( - (8, num_contiguous_elems), stride=(num_contiguous_elems, 1) - ), - loc=loc, - ip=ip, + outer = core.make_layout( + (8, num_contiguous_elems), stride=(num_contiguous_elems, 1), loc=loc, ip=ip ) + return core.make_composed_layout(sw, 0, outer, loc=loc, ip=ip) + @overload def tile_to_mma_shape( diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py index f81bd21b3..273253a22 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py index 6d1e30344..65eacac44 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py index baff48395..0f707b28d 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py index 781128b64..c08438527 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py index 106c82e4d..80dc24441 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py index 6098b7f3b..05b8a2d8d 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py index 707625f02..bf5d7110d 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/runtime.py b/python/CuTeDSL/cutlass/cute/runtime.py index 9a68a2d13..da1a34d20 100644 --- a/python/CuTeDSL/cutlass/cute/runtime.py +++ b/python/CuTeDSL/cutlass/cute/runtime.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/tensor.py b/python/CuTeDSL/cutlass/cute/tensor.py index 3196e6b22..587ee920b 100644 --- a/python/CuTeDSL/cutlass/cute/tensor.py +++ b/python/CuTeDSL/cutlass/cute/tensor.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -426,9 +426,10 @@ class _Tensor(Tensor): return _cute_ir.get_layout(self.value, loc=loc, ip=ip) @property + @dsl_user_op @lru_cache_ir() - def shape(self) -> Shape: - return self.layout.shape + def shape(self, *, loc=None, ip=None) -> Shape: + return self.layout.shape_method(loc=loc, ip=ip) @property @lru_cache_ir() @@ -1011,7 +1012,9 @@ def recast_tensor( src_iter = recast_ptr(src.iterator, dtype=dtype, loc=loc, ip=ip) src_layout = recast_layout(dst_width, src_width, src.layout, loc=loc, ip=ip) - return make_tensor(src_iter, src_layout, loc=loc, ip=ip) + return type(src)( + make_tensor(src_iter, src_layout, loc=loc, ip=ip), dtype=dtype, loc=loc, ip=ip + ) @dsl_user_op @@ -1344,7 +1347,23 @@ class TensorSSA(cutlass_arith.ArithValue): if issubclass(rhs.dtype, Integer): rhs_val = rhs_val.with_signedness(rhs.dtype.signed) - res_vect = op(lhs_val, rhs_val) + # Use ArithValue's operator method directly to avoid recursion + # through TensorSSA's __add__/__sub__/etc. when op() dispatches + # back to the subclass method + if op.__name__ == "_min": + arith_op = cutlass_arith._min + elif op.__name__ == "_max": + arith_op = cutlass_arith._max + elif op in (operator.and_, operator.or_): + arith_op_name = f"__{op.__name__}_" + arith_op = getattr(cutlass_arith.ArithValue, arith_op_name) + else: + arith_op_name = f"__{op.__name__}__" + arith_op = getattr(cutlass_arith.ArithValue, arith_op_name, None) + if arith_op: + res_vect = arith_op(lhs_val, rhs_val, loc=loc, ip=ip) + else: + res_vect = op(lhs_val, rhs_val) res = TensorSSA(res_vect, lhs._shape, res_type) return res diff --git a/python/CuTeDSL/cutlass/cute/testing.py b/python/CuTeDSL/cutlass/cute/testing.py index 3cdcd84cc..618522a1b 100644 --- a/python/CuTeDSL/cutlass/cute/testing.py +++ b/python/CuTeDSL/cutlass/cute/testing.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/tuple.py b/python/CuTeDSL/cutlass/cute/tuple.py index 15c4114e2..66a334339 100644 --- a/python/CuTeDSL/cutlass/cute/tuple.py +++ b/python/CuTeDSL/cutlass/cute/tuple.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cute/typing.py b/python/CuTeDSL/cutlass/cute/typing.py index 1cc1af2db..77e171598 100644 --- a/python/CuTeDSL/cutlass/cute/typing.py +++ b/python/CuTeDSL/cutlass/cute/typing.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py b/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py index 077d66b76..6bc85cd4b 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py index 72f57b8f9..a854ec4e0 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_stream_adapter.py b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_stream_adapter.py index c913624de..c4d41b6c5 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_stream_adapter.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_stream_adapter.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py index c414899f3..b62e3a8f1 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -16,6 +16,7 @@ regarding to that dialect. # Local module imports from types import GenericAlias, SimpleNamespace, UnionType +from typing_extensions import deprecated from typing import ( Callable, Union, @@ -103,6 +104,25 @@ from .cutlass_ast_decorators import ( from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry +# ============================================================================= +# Cutlass DSL Device Info +# ============================================================================= + +# Contains a map of SM architecture to shared memory capacity in bytes +SMEM_CAPACITY_MAP = { + "sm_121": (100 - 1) * 1024, + "sm_120": (100 - 1) * 1024, + "sm_110": (228 - 1) * 1024, + "sm_103": (228 - 1) * 1024, + "sm_101": (228 - 1) * 1024, + "sm_100": (228 - 1) * 1024, + "sm_90": (228 - 1) * 1024, + "sm_89": (100 - 1) * 1024, + "sm_86": (100 - 1) * 1024, + "sm_87": (164 - 1) * 1024, + "sm_80": (164 - 1) * 1024, +} + # ============================================================================= # Cutlass DSL Base Abstract Class # ============================================================================= @@ -812,6 +832,27 @@ class CutlassBaseDSL(BaseDSL): ) cfg.smem = const(cfg.smem) + # Warn user if shared memory exceed arch max + # Currently runtime only show 'CUDA_ERROR_INVALID_VALUE' error which is not useful + arch = self.dsl.get_arch_enum() + arch_str = f"sm_{arch.major}{arch.minor}" + if arch_str in SMEM_CAPACITY_MAP: + arch_smem = SMEM_CAPACITY_MAP[arch_str] + smem_msg = ( + f"\nError: kernel '{kernelSym}' launch shared memory " + f"exceeds current GPU arch {arch} allowed. " + f"Allocated: {{}} bytes. Max: {arch_smem} bytes.\n\n" + ) + if_generate( + arch_smem < cfg.smem, + lambda: cute.print_([cfg.smem], fmt=smem_msg), + loc=loc, + ) + else: + raise DSLRuntimeError( + f"Lack smem capacity info for GPU arch {arch}." + ) + async_deps = cfg.async_deps if not isinstance(cfg.async_deps, (list, tuple)): async_deps = [cfg.async_deps] @@ -1175,6 +1216,7 @@ class KernelLauncher: self.func_kwargs = func_kwargs self._name_prefix = func_kwargs.pop("_name_prefix", None) + self._launch_name = None self._check_func_args(funcBody, *func_args, **func_kwargs) @@ -1192,7 +1234,10 @@ class KernelLauncher: cause=e, ) - def smem_usage(self) -> int: + @deprecated( + "`smem_usage()` is deprecated, use public API `arch.dynamic_smem_size()` instead." + ) + def smem_usage(self) -> Int32: """ Check smem usage for this kernel, only available after `launch` """ @@ -1215,6 +1260,7 @@ class KernelLauncher: ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config) self.dsl.kernel_info[name] = kernel_attrs + self._launch_name = name return ret.launch_op_ret def __call__(self, *args, **kwargs): @@ -1443,7 +1489,7 @@ def _minmax(op, *args, loc=None, ip=None): for x in xs: emitter = getattr(cutlass_arith, f"_{op.__name__}") if not (is_dynamic_expression(res) or is_dynamic_expression(x)): - res = emitter(op(res), op(x)) + res = emitter(op(res), op(x), loc=loc, ip=ip) elif ( hasattr(res, "type") and hasattr(x, "type") @@ -1473,7 +1519,7 @@ def _minmax(op, *args, loc=None, ip=None): rhs_val = rhs.value.with_signedness(rhs.signed) else: rhs_val = rhs.value - res = res_type(emitter(lhs_val, rhs_val), loc=loc, ip=ip) + res = res_type(emitter(lhs_val, rhs_val, loc=loc, ip=ip), loc=loc, ip=ip) x = res else: raise DSLNotImplemented(f"{type(args)} is not supported") diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py index 214eb710b..26cc159f9 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py b/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py index 14f564d19..b63c863be 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/impl_utils.py b/python/CuTeDSL/cutlass/impl_utils.py index fef00be1b..29bd22786 100644 --- a/python/CuTeDSL/cutlass/impl_utils.py +++ b/python/CuTeDSL/cutlass/impl_utils.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/jax/__init__.py b/python/CuTeDSL/cutlass/jax/__init__.py index 082c3818d..286ce88f5 100644 --- a/python/CuTeDSL/cutlass/jax/__init__.py +++ b/python/CuTeDSL/cutlass/jax/__init__.py @@ -3,25 +3,48 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. from functools import cache +import logging + +logger = logging.getLogger(__name__) + +# This is the minimum JAX version that will work with CuTeDSL JAX extensions. +# +# See the following pages for details on JAX versioning: +# - https://docs.jax.dev/en/latest/jep/25516-effver.html +# - https://docs.jax.dev/en/latest/jep/9419-jax-versioning.html +CUTE_DSL_MIN_SUPPORTED_JAX_VERSION = (0, 5, 0) + @cache def is_available(): - """Returns true of Jax support is enabled.""" + """Returns true if JAX extensions are supported and available.""" try: import jax + import jax.numpy # Also verify jax.numpy is available + except ImportError: + logger.debug( + "CuTeDSL JAX extensions are not available because JAX was not found or could not be imported." + ) + return False - _HAVE_JAX = True - except ImportError as e: - _HAVE_JAX = False + if not ( + hasattr(jax.version, "__version_info__") + and jax.version.__version_info__ >= CUTE_DSL_MIN_SUPPORTED_JAX_VERSION + ): + logger.debug( + f"Your installed JAX v{jax.__version__} too old and not supported by CuTeDSL JAX extensions.\n" + "Please upgrade to the latest version." + ) + return False - return _HAVE_JAX + return True if is_available(): @@ -29,6 +52,8 @@ if is_available(): from .types import ( jax_to_cutlass_dtype, cutlass_to_jax_dtype, + jax_to_cutlass_layout_order, + cutlass_to_jax_layout_order, from_dlpack, JaxArray, TensorSpec, @@ -41,6 +66,7 @@ if is_available(): find_cute_dsl_runtime_library, register_ffi, is_ffi_registered, + get_cutlass_call_ffi_version, ) from . import testing @@ -48,18 +74,24 @@ if is_available(): TensorMode = TensorSpec __all__ = [ + "CUTE_DSL_MIN_SUPPORTED_JAX_VERSION", "cutlass_call", "jax_to_cutlass_dtype", "cutlass_to_jax_dtype", + "jax_to_cutlass_layout_order", + "cutlass_to_jax_layout_order", "from_dlpack", "JaxArray", "TensorSpec", "TensorMode", "release_compile_cache", "get_export_disabled_safety_checks", + "is_ffi_registered", + "register_ffi", + "get_cutlass_call_ffi_version", "is_available", "testing", ] else: # export is_available check for callers or tests. - __all__ = ["is_available"] + __all__ = ["CUTE_DSL_MIN_SUPPORTED_JAX_VERSION", "is_available"] diff --git a/python/CuTeDSL/cutlass/jax/compile.py b/python/CuTeDSL/cutlass/jax/compile.py index fa2f014b2..609180cbf 100644 --- a/python/CuTeDSL/cutlass/jax/compile.py +++ b/python/CuTeDSL/cutlass/jax/compile.py @@ -3,20 +3,16 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -import os import gc -import ctypes -import inspect -from typing import Any, Callable, Optional, Sequence +from typing import Any from dataclasses import dataclass from functools import partial -from pathlib import Path import time import logging @@ -27,17 +23,13 @@ import cuda.bindings.driver as cuda import jax import jax.numpy as jnp -import jaxlib from .types import ( jax_to_cutlass_dtype, - from_dlpack, - JaxArray, JaxArrayList, TensorSpec, JaxTracedArray, DEFAULT_CUTLASS_DEVICE_MEMSPACE, - DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT, ) import cutlass @@ -90,6 +82,7 @@ class FunctionSpec: leaf.spec.layout, leaf.spec.mode, leaf.get_static_flag(self.use_static_tensors), + leaf.spec.divisibility, ) for leaf in self.in_args ] @@ -102,6 +95,7 @@ class FunctionSpec: leaf.spec.layout, leaf.spec.mode, leaf.get_static_flag(self.use_static_tensors), + leaf.spec.divisibility, ) for leaf in self.out_args ] diff --git a/python/CuTeDSL/cutlass/jax/ffi.py b/python/CuTeDSL/cutlass/jax/ffi.py index 087bfd4d5..5aee47fcc 100644 --- a/python/CuTeDSL/cutlass/jax/ffi.py +++ b/python/CuTeDSL/cutlass/jax/ffi.py @@ -3,16 +3,15 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Sequence +from typing import Sequence, Optional from pathlib import Path from functools import cache -import os import logging import ctypes @@ -27,22 +26,51 @@ logger = logging.getLogger(__name__) _CUTE_DSL_RUNTIME_LIBRARY_NAME = "cute_dsl_runtime" -_CUTLASS_CALL_TARGETS = { +# V1 targets for older jax clients +_CUTLASS_CALL_TARGETS_V1 = { "CuteDSLRT_NvJaxCutlassCall": { - "execute": "CuteDSLRT_NvJaxCutlassCallExecute", - "prepare": "CuteDSLRT_NvJaxCutlassCallPrepare", + "prepare": "CuteDSLRT_NvJaxCutlassCallPrepare_v1", + "execute": "CuteDSLRT_NvJaxCutlassCallExecute_v1", }, "CuteDSLRT_NvJaxCutlassCallNoCudaGraph": { - "execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph", - "prepare": "CuteDSLRT_NvJaxCutlassCallPrepare", + "prepare": "CuteDSLRT_NvJaxCutlassCallPrepare_v1", + "execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph_v1", }, } +# V2 targets for newer jax clients supporting stateful FFI calls. +_JAX_FFI_V2_MIN_VERSION = (0, 9, 1) +_CUTLASS_CALL_TARGETS_V2 = { + "CuteDSLRT_NvJaxCutlassCall": { + "execute": "CuteDSLRT_NvJaxCutlassCallExecute_v2", + "instantiate": "CuteDSLRT_NvJaxCutlassCallInstantiate_v2", + "prepare": "CuteDSLRT_NvJaxCutlassCallPrepare_v2", + }, + "CuteDSLRT_NvJaxCutlassCallNoCudaGraph": { + "execute": "CuteDSLRT_NvJaxCutlassCallExecuteNoCudaGraph_v2", + "instantiate": "CuteDSLRT_NvJaxCutlassCallInstantiate_v2", + "prepare": "CuteDSLRT_NvJaxCutlassCallPrepare_v2", + }, +} +_CUTLASS_CALL_TYPES_V2 = { + "CuteDSLRT_NvJaxCutlassCallTypes": { + "type_id": "CuteDSLRT_NvJaxCutlassCallStateTypeId_v2", + "type_info": "CuteDSLRT_NvJaxCutlassCallStateTypeInfo_v2", + } +} -def get_cutlass_call_ffi_name(allow_cuda_graph): + +def get_cutlass_call_ffi_version() -> int: + """Returns the FFI API version based on JAX version.""" + if jax.version.__version_info__ >= _JAX_FFI_V2_MIN_VERSION: + return 2 + else: + return 1 + + +def get_cutlass_call_ffi_name(allow_cuda_graph: bool) -> str: """Returns the FFI target to call when running cutlass_call functions.""" - disable_cuda_graph = not allow_cuda_graph - if not disable_cuda_graph: + if allow_cuda_graph: return "CuteDSLRT_NvJaxCutlassCall" else: return "CuteDSLRT_NvJaxCutlassCallNoCudaGraph" @@ -50,14 +78,14 @@ def get_cutlass_call_ffi_name(allow_cuda_graph): def get_export_disabled_safety_checks() -> Sequence[jax.export.DisabledSafetyCheck]: """Returns jax.export.DisabledSafetyCheck to allow cutlass_call kernels.""" - checks = [] - for target in _CUTLASS_CALL_TARGETS: - checks.append(jax.export.DisabledSafetyCheck.custom_call(target)) - return tuple(checks) + targets = set(_CUTLASS_CALL_TARGETS_V1.keys()) | set( + _CUTLASS_CALL_TARGETS_V2.keys() + ) + return tuple([jax.export.DisabledSafetyCheck.custom_call(t) for t in targets]) @cache -def find_cute_dsl_runtime_library(): +def find_cute_dsl_runtime_library() -> Optional[str]: """Searches for the CuTeDSL runtime library.""" dsl = CuTeDSL._get_dsl() candidate_libs = [] @@ -85,7 +113,10 @@ def find_cute_dsl_runtime_library(): candidate_libs.extend(dsl_libs) except Exception as e: - logger.debug(f"Failed to locate libraries due to an exception:", e) + logger.debug( + f"Failed to locate {_CUTE_DSL_RUNTIME_LIBRARY_NAME} library: {e}", + exc_info=True, + ) for lib in candidate_libs: if lib.endswith(f"{_CUTE_DSL_RUNTIME_LIBRARY_NAME}.so"): @@ -97,8 +128,12 @@ def find_cute_dsl_runtime_library(): _FFI_CALLS_REGISTERED = False -def register_ffi(): - """Registers custom calls with Jax/XLA runtime.""" +def register_ffi(ffi_version: int = get_cutlass_call_ffi_version()): + """Registers custom calls with Jax/XLA runtime. + + A specific version can be requested using `ffi_version` argument. Attempting + to register non default FFI versions may not work with your specific JAX. + """ global _FFI_CALLS_REGISTERED if _FFI_CALLS_REGISTERED: return @@ -112,27 +147,36 @@ def register_ffi(): lib = ctypes.CDLL(runtime_library) - def _capsule(funcptr): - destructor = ctypes.CFUNCTYPE(None, ctypes.py_object) - builder = ctypes.pythonapi.PyCapsule_New - builder.restype = ctypes.py_object - builder.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor) - return builder(funcptr, None, destructor(0)) - def _register_ffi_targets(lib, targets): for target_name, target in targets.items(): handler = {} for stage, fn_name in target.items(): fn = getattr(lib, fn_name) fn.restype = ctypes.c_void_p - handler[stage] = _capsule(fn) + handler[stage] = jax.ffi.pycapsule(fn) logger.debug(f"Registering ffi handler: {target_name}, {handler}") - jax.ffi.register_ffi_target( - target_name, handler["execute"], platform="CUDA" - ) + jax.ffi.register_ffi_target(target_name, handler, platform="CUDA") - # Register the custom FFI targets - _register_ffi_targets(lib, _CUTLASS_CALL_TARGETS) + def _register_ffi_types(lib, types): + for type_name, type_dict_targets in types.items(): + type_dict = {} + for field, fn_name in type_dict_targets.items(): + fn = getattr(lib, fn_name) + fn.restype = ctypes.c_void_p + type_dict[field] = jax.ffi.pycapsule(fn()) + logger.debug(f"Registering ffi type: {type_name}, {type_dict}") + jax.ffi.register_ffi_type(type_name, type_dict, platform="CUDA") + + # Register the custom FFI targets. + match ffi_version: + case 1: + _register_ffi_targets(lib, _CUTLASS_CALL_TARGETS_V1) + # no types for v1 + case 2: + _register_ffi_types(lib, _CUTLASS_CALL_TYPES_V2) + _register_ffi_targets(lib, _CUTLASS_CALL_TARGETS_V2) + case _: + raise ValueError(f"Invalid FFI version {ffi_version}") _FFI_CALLS_REGISTERED = True diff --git a/python/CuTeDSL/cutlass/jax/primitive.py b/python/CuTeDSL/cutlass/jax/primitive.py index f9d838ce5..0dab44306 100644 --- a/python/CuTeDSL/cutlass/jax/primitive.py +++ b/python/CuTeDSL/cutlass/jax/primitive.py @@ -3,33 +3,28 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Any, Union, Sequence, Callable -from functools import partial +from typing import Any, Sequence, Callable import logging -import os -import cuda.bindings.driver as cuda -import jax, jax.numpy as jnp +import jax import jax.extend from jax.interpreters import mlir from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src import ffi -from jax.tree import flatten, unflatten -import cutlass from .compile import get_or_compile_kernel, build_function_spec -from .types import row_major_layout, default_tensor_spec, TensorSpec +from .types import cutlass_to_jax_layout_order, default_tensor_spec, TensorSpec from .ffi import get_cutlass_call_ffi_name, is_ffi_registered, register_ffi + logger = logging.getLogger(__name__) cutlass_call_inner_p = jax.extend.core.Primitive("cutlass_call_inner") @@ -39,7 +34,7 @@ cutlass_call_inner_p.multiple_results = True def cutlass_call( fn: Callable[..., None], *, - output_shape_dtype: Any, + output_shape_dtype: Any = None, input_spec: Any = None, output_spec: Any = None, input_mode: Any = None, @@ -50,29 +45,64 @@ def cutlass_call( use_static_tensors=False, **kwargs, ): - """Creates a callable that invokes a @cute.jit function. + """Create a callable that invokes a ``@cute.jit`` function from JAX. + + Returns a callable that accepts JAX arrays and dispatches to *fn* as part + of a ``jax.jit``-compiled computation. The kernel is compiled once on the + first call and cached for subsequent invocations with the same shapes and + specs. + + Example:: + + @cute.jit + def my_kernel(stream, A, B, C, D): + ... + + @jax.jit + def run(a, b): + return cutlass_call( + my_kernel, + output_shape_dtype=( + jax.ShapeDtypeStruct(a.shape, a.dtype), + jax.ShapeDtypeStruct(b.shape, b.dtype), + ), + )(a, b) + + c, d = run(a, b) Args: - fn: A @cute.jit decorated function that launches a cutlass kernel. - output_shape_dtype: A pytree representing the shape and dtype of the output buffers. - input_output_aliases: Optional mapping of input to output aliases. Positions are specified assuming - a flattened input and output pytree. - input_spec: Specifies a cute.Tensor dimension order for input tensors. If None then the order - will assume the corresponding layout order. - output_spec: Specifies a cute.Tensor dimension order for output tensors. If None then the order - will assume the corresponding layout order. - input_mode: Legacy alias for input_spec. This parameter may be removed in future versions. - output_spec: Legacy alias for output_spec. This parameter may be removed in future versions. - allow_cuda_graph: If false will prevent XLA from building a cuda graph of for this call. - compile_options: Optional compiler arguments to pass into cute.compile. - use_static_tensors: If True, tensor shapes and strides are treated as constexpr values by - default. This can improve performance through compiler specialization but may not work - properly with all kernels. Specific tensors may be marked static or dynamic using the mode - and override this flag. - kwargs: Optional constexpr parameters to pass into the kernel fn. + fn: A ``@cute.jit``-decorated function with the signature + ``(stream, *inputs, *outputs, **kwargs)``. + output_shape_dtype: A pytree of :class:`jax.ShapeDtypeStruct` (or + objects with ``.shape`` and ``.dtype`` attributes) describing each + output buffer. + input_spec: A :class:`TensorSpec` or list thereof providing + layout/mode/divisibility hints for input tensors. ``None`` infers + defaults from each array. + output_spec: Same as *input_spec* but applied to output tensors. + input_output_aliases: ``{input_index: output_index}`` mapping that + allows an input buffer to alias an output, avoiding an extra copy. + Indices are into the flattened input and output pytrees. + allow_cuda_graph: If ``False``, prevents XLA from capturing this call + in a CUDA graph. Defaults to ``True``. + compile_options: Optional dict of compiler flags forwarded to + ``cute.compile``. + use_static_tensors: If ``True``, tensor shapes and strides are baked in + as compile-time constants, improving performance when shapes are + fixed across calls. Defaults to ``False``. + **kwargs: Additional keyword arguments forwarded to *fn* as compile-time + constants. - Note: This API is experimental and subject to change! + Returns: + A callable ``(*arrays) -> output_pytree`` that can be used inside + ``jax.jit``. + + Note: + This API is experimental and subject to change. """ + if output_shape_dtype is None: + raise ValueError("'output_shape_dtype' must be specified.") + output_shape_dtype = jax.tree.map( lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype), output_shape_dtype ) @@ -107,22 +137,84 @@ def cutlass_call( ) -def _normalize_tensor_spec(value: Any): - if value is None: - return [None] - elif isinstance(value, (tuple, list)): - if isinstance(value[0], int): # single tuple of modes - return TensorSpec(mode=tuple(value)) +def _is_spec_leaf(x: Any) -> bool: + """Return True if *x* should be treated as a leaf when traversing a spec pytree. + + Stops traversal at ``TensorSpec`` and ``None`` (both are valid leaf specs) and at + bare integer sequences (the legacy mode-spec shorthand). Everything else is + treated as a pytree container and recursed into by ``jax.tree.leaves``. + """ + if x is None or isinstance(x, TensorSpec): + return True + # Legacy: a bare sequence of ints represents a single TensorSpec(mode=...). + # Check *all* elements so that mixed sequences like (1, TensorSpec()) are NOT + # mistaken for a mode spec and instead cause a TypeError below. + if isinstance(x, (list, tuple)) and bool(x) and all(isinstance(i, int) for i in x): + return True + return False + + +def _normalize_tensor_spec(value: Any) -> list[TensorSpec | None]: + """Normalize a spec pytree into a flat list of ``TensorSpec | None`` entries. + + *value* may be any JAX pytree whose leaves are ``TensorSpec``, ``None``, or a + bare integer sequence (legacy shorthand for ``TensorSpec(mode=...)``). Dict and + other non-list/tuple pytree containers are supported via ``jax.tree.leaves``. + + Note: ``TensorSpec`` is itself a JAX-registered dataclass with all-static fields, + so traversal *must* use the ``is_leaf`` predicate to stop at ``TensorSpec`` nodes + rather than recursing into them (which would yield no children and silently drop + the spec). + """ + leaves = jax.tree.leaves(value, is_leaf=_is_spec_leaf) + result = [] + for leaf in leaves: + if leaf is None or isinstance(leaf, TensorSpec): + result.append(leaf) + elif isinstance(leaf, (list, tuple)): + # Legacy: bare int sequence → TensorSpec(mode=...) + result.append(TensorSpec(mode=tuple(leaf))) else: - flat, _ = jax.tree.flatten( - [_normalize_tensor_spec(x) for x in value], - is_leaf=lambda x: x is None or isinstance(x, TensorSpec), + raise TypeError( + f"Unexpected value for TensorSpec: {leaf!r} ({type(leaf).__name__})" + ) + return result + + +def _resolve_spec_flat(spec: Any, tensors: list) -> tuple[TensorSpec, ...]: + """Normalize *spec* and fill any ``None`` slots with defaults inferred from *tensors*.""" + if spec is None: + return tuple(default_tensor_spec(t) for t in tensors) + specs = list(_normalize_tensor_spec(spec)) + if len(specs) != len(tensors): + raise ValueError( + f"Must have the same number of specs ({len(specs)}) as tensors ({len(tensors)})." + ) + return tuple( + default_tensor_spec(t) if s is None else s for s, t in zip(specs, tensors) + ) + + +def _validate_specs(label: str, tensors: list, specs: tuple[TensorSpec, ...]) -> None: + """Validate that each spec's rank-dependent fields match the corresponding tensor shape.""" + for idx, (tensor, spec) in enumerate(zip(tensors, specs)): + ndim = len(tensor.shape) + if spec.layout is not None and len(spec.layout) != ndim: + raise ValueError( + f"{label} #{idx} has invalid layout {spec.layout} for shape {tensor.shape}." + ) + if spec.mode is not None and len(spec.mode) != ndim: + raise ValueError( + f"{label} #{idx} has invalid mode {spec.mode} for shape {tensor.shape}." + ) + if ( + spec.divisibility is not None + and not isinstance(spec.divisibility, int) + and len(spec.divisibility) != ndim + ): + raise ValueError( + f"{label} #{idx} has invalid divisibility {spec.divisibility} for shape {tensor.shape}." ) - return flat - elif isinstance(value, TensorSpec): - return [value] - else: - raise TypeError(f"Unexpected value for TensorMode {value} {type(value)}") def _cutlass_call_impl( @@ -137,70 +229,21 @@ def _cutlass_call_impl( use_static_tensors, **kwargs, ): + # A single ShapeDtypeStruct means one output; a sequence means multiple. multiple_results = isinstance(output_shape_dtype, Sequence) if not multiple_results: output_shape_dtype = (output_shape_dtype,) output_shape_dtype_flat, output_tree = jax.tree.flatten(output_shape_dtype) - @partial(jax.jit, inline=True) + @jax.jit def call_wrapper(*args): args_flat, args_tree = jax.tree.flatten(args) - if input_spec is None: - input_spec_flat = tuple(default_tensor_spec(x) for x in args_flat) - else: - input_spec_flat = _normalize_tensor_spec(input_spec) - for idx, (spec, arg) in enumerate(zip(input_spec_flat, args_flat)): - if spec is None: - input_spec_flat[idx] = default_tensor_spec(arg) - input_spec_flat = tuple(input_spec_flat) + input_spec_flat = _resolve_spec_flat(input_spec, args_flat) + output_spec_flat = _resolve_spec_flat(output_spec, output_shape_dtype_flat) - if output_spec is None: - output_spec_flat = tuple( - default_tensor_spec(x) for x in output_shape_dtype_flat - ) - else: - output_spec_flat = _normalize_tensor_spec(output_spec) - for idx, (spec, arg) in enumerate( - zip(output_spec_flat, output_shape_dtype_flat) - ): - if spec is None: - output_spec_flat[idx] = default_tensor_spec(arg) - output_spec_flat = tuple(output_spec_flat) - if len(input_spec_flat) != len(args_flat): - raise ValueError( - f"Must has same number of input modes ({len(input_spec_flat)}) as input arrays ({len(args_flat)})." - ) - - if len(output_spec_flat) != len(output_shape_dtype_flat): - raise ValueError( - f"Must has same number of output modes ({len(output_spec_flat)}) as output arrays ({len(output_shape_dtype_flat)})." - ) - - # Validate dynamic mode settings match whatever static shape - # information we got as input. - for idx, (arg, spec) in enumerate(zip(args_flat, input_spec_flat)): - if spec.layout is not None and len(spec.layout) != len(arg.shape): - raise ValueError( - f"Input #{idx} has invalid layout {spec.layout} for shape {arg.shape}." - ) - if spec.mode is not None and len(spec.mode) != len(arg.shape): - raise ValueError( - f"Input #{idx} has invalid mode {spec.mode} for shape {arg.shape}." - ) - - for idx, (arg, spec) in enumerate( - zip(output_shape_dtype_flat, output_spec_flat) - ): - if spec.layout is not None and len(spec.layout) != len(arg.shape): - raise ValueError( - f"Output #{idx} has invalid layout {spec.layout} for shape {arg.shape}." - ) - - if spec.mode is not None and len(spec.mode) != len(arg.shape): - raise ValueError( - f"Output #{idx} has invalid mode {spec.mode} for shape {arg.shape}." - ) + _validate_specs("Input", args_flat, input_spec_flat) + _validate_specs("Output", output_shape_dtype_flat, output_spec_flat) output_flat = cutlass_call_inner_p.bind( *args_flat, @@ -208,8 +251,8 @@ def _cutlass_call_impl( args_tree=args_tree, output_shape_dtype_flat=tuple(output_shape_dtype_flat), output_tree=output_tree, - input_spec_flat=tuple(input_spec_flat), - output_spec_flat=tuple(output_spec_flat), + input_spec_flat=input_spec_flat, + output_spec_flat=output_spec_flat, input_output_aliases=tuple(input_output_aliases.items()), allow_cuda_graph=allow_cuda_graph, compile_options=compile_options, @@ -264,10 +307,17 @@ def cutlass_call_inner_p_impl( register_ffi() call_name = get_cutlass_call_ffi_name(allow_cuda_graph) + + # Convert layout from CuTeDSL to JAX order as ffi_call expects this. + input_layouts = [cutlass_to_jax_layout_order(s.layout) for s in input_spec_flat] + output_layouts = [cutlass_to_jax_layout_order(s.layout) for s in output_spec_flat] + fun = jax.ffi.ffi_call( call_name, result_shape_dtypes=output_shape_dtype_flat, input_output_aliases=dict(spec.input_output_aliases), + input_layouts=input_layouts, + output_layouts=output_layouts, ) return fun(*args_flat, module=kernel.module, key=kernel.fingerprint) diff --git a/python/CuTeDSL/cutlass/jax/testing.py b/python/CuTeDSL/cutlass/jax/testing.py index 781671a23..9c5d9a0ef 100644 --- a/python/CuTeDSL/cutlass/jax/testing.py +++ b/python/CuTeDSL/cutlass/jax/testing.py @@ -3,13 +3,12 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from functools import partial import jax import jax.numpy as jnp @@ -17,6 +16,7 @@ import jax.numpy as jnp import cutlass.cute as cute from cutlass.cutlass_dsl import dsl_user_op + def reorder_modes(src: str, target: str) -> tuple[int, ...]: """Computes the mode given a source and target order.""" src = tuple(src) @@ -88,6 +88,7 @@ def get_gemm_shape_from_tensors( n = b.shape[0] return (m, n, k, l) + def create_tensor( shape, dtype, key, *, minval=-2.0, maxval=2.0, fill_value=None, fill_arange=False ): diff --git a/python/CuTeDSL/cutlass/jax/types.py b/python/CuTeDSL/cutlass/jax/types.py index 0fa414ed7..56e5d2aca 100644 --- a/python/CuTeDSL/cutlass/jax/types.py +++ b/python/CuTeDSL/cutlass/jax/types.py @@ -3,24 +3,15 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Type, Optional, Sequence, Union, Callable, Any, TypeVar -import sys -import ctypes -import math -import inspect +from typing import Sequence from dataclasses import dataclass, field -from functools import partial, reduce -from operator import mul -from itertools import chain -from typing import Annotated -import cuda.bindings.driver as cuda import jax import jax.numpy as jnp @@ -28,13 +19,13 @@ import jax.numpy as jnp import cutlass import cutlass.cute as cute from cutlass.cute.runtime import from_dlpack as _from_dlpack -from cutlass.cute import AddressSpace, Numeric, IntTuple +from cutlass.cute import AddressSpace from cutlass._mlir import ir from cutlass._mlir.dialects import llvm, arith -import cutlass._mlir.dialects.cute as _cute_ir JAX_DTYPE_TO_CUTLASS_DTYPE = { jnp.bool.dtype: cutlass.Boolean, + jnp.int4.dtype: cutlass.Int4, jnp.int8.dtype: cutlass.Int8, jnp.int16.dtype: cutlass.Int16, jnp.int32.dtype: cutlass.Int32, @@ -65,38 +56,69 @@ DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT = 256 @jax.tree_util.register_dataclass @dataclass(frozen=True) class TensorSpec: - """Provides a specification of cute.Tensor modes and additional metadata about - dynamic/static shapes for compilation. + """Specifies the layout and metadata for a JAX array passed to a CuTe kernel. - Arguments: - layout : Specifies the position of stries as they relate to the framework - tensor (S0, S1, ... SN) - mode : Specifies the position of each mode in the tensor (M0, M1, ... MN) - static : Specifies the tensor shape is represented as static constexpr. - ptr_assumed_align: Specifies the pointer alignment. + TensorSpec controls how a JAX array's dimensions are mapped to a cute.Tensor + during jit lowering, including stride ordering, mode permutation, and whether + shapes/strides are compiled as static constants. + + Attributes: + layout: A minor-to-major stride ordering in CuTeDSL convention. ``layout[i]`` + gives the stride rank of dimension ``i``, where rank 0 means the smallest + (innermost) stride. For example, row-major order for a 3-D tensor is + ``(2, 1, 0)``. If ``None``, row-major is assumed. Use + :func:`jax_to_cutlass_layout_order` to convert from JAX's major-to-minor + convention. + mode: A permutation that maps the stride-ordered dimensions to the mode + positions of the resulting ``cute.Layout``. For example, ``mode=(2, 0, 1)`` + reorders an ``(M, K, L)`` layout into ``(K, L, M)`` mode order inside the + kernel. If ``None``, modes match the natural dimension order ``(0, 1, ..., N-1)``. + static: If ``True``, shapes and strides are compiled as static ``constexpr`` + values, which may enable additional compiler optimisations. Kernels that + do not support static shapes will raise a compile error. Must be ``False`` + when any dimension is symbolic (e.g. under ``jax.export``). + ptr_assumed_align: Assumed byte alignment of the tensor's data pointer. + Overrides the default of 256 bytes. Rarely needs to change. + divisibility: Optional per-mode divisibility hints. If a single int is passed + divisibility will be applied to the leading (stride=1) dimension only. """ - # Specifies the layout of the Jax array. If not it will be assumed that the layout - # is row major. + # Minor-to-major stride ordering in CuTeDSL convention (layout[i] = stride rank + # of dimension i, 0 = innermost). Defaults to row-major if None. layout: tuple[int, ...] | None = field(metadata=dict(static=True), default=None) - # Indicates the order of modes. If unspecified the modes will match exactly with - # the layout of the Jax tensor (e.g. row-major). Typically used to map from the - # input layout to kernel layouts (e.g. MKL/NKL/MNL). + # Permutation from stride-ordered dimensions to cute.Layout mode positions. + # Defaults to identity (0, 1, ..., N-1) if None. mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None) - # Indicates the shape and strides will be defined statically. Setting True ay enable - # additional optimization. Kernels that do not support static shapes will generate - # compile errors if this is enabled so we leave it off by default. + # If True, shapes and strides are embedded as compile-time constants. + # Must be False for symbolic/dynamic shapes (e.g. jax.export). static: bool = field(metadata=dict(static=True), default=None) - # Overrides the default pointer alignment. Generally this should not be changed - # but is left here to provide a hook. + # Assumed alignment (bytes) of the data pointer. Default matches XLA's 256-byte alignment. ptr_assumed_align: int = field( metadata=dict(static=True), default=DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNMENT ) + # Per-mode divisibility hints. + divisibility: tuple[int | None, ...] | int | None = field( + metadata=dict(static=True), default=None + ) + def row_major_layout(shaped): - """Returns a row major layout given a shaped value. + """Returns the CuTeDSL minor-to-major stride ordering for a row-major (C-contiguous) tensor. - Row major layout is (N-1, N-2, ... 1, 0) for an N-dimensional tensor. + In CuTeDSL convention, ``layout[i]`` is the stride rank of dimension ``i``, + where rank 0 denotes the innermost (stride-1) dimension. Row-major means the + last dimension is innermost, so the result is ``(N-1, N-2, ..., 1, 0)`` for an + N-dimensional tensor. + + Example:: + + row_major_layout((M, K, N)) # → (2, 1, 0) + + Args: + shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence. + + Returns: + A tuple of length N representing the minor-to-major ordering. """ if hasattr(shaped, "shape"): shaped = shaped.shape @@ -104,9 +126,17 @@ def row_major_layout(shaped): def default_tensor_mode(shaped): - """Returns a default tensor mode given a shaped value. + """Returns the identity mode permutation for an N-dimensional tensor. - Default mode is (0, 1, ... N-2, N-1) for an N_dimensional tensor. + The mode permutation maps stride-ordered dimensions to ``cute.Layout`` mode + positions. The default identity ``(0, 1, ..., N-1)`` leaves the mode order + unchanged relative to the dimension order. + + Args: + shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence. + + Returns: + A tuple ``(0, 1, ..., N-1)`` of length N. """ if hasattr(shaped, "shape"): shaped = shaped.shape @@ -114,14 +144,108 @@ def default_tensor_mode(shaped): def default_tensor_spec(shaped) -> TensorSpec: - """Returns a default tensor spec given a shaped value. + """Returns a :class:`TensorSpec` with row-major layout and identity mode ordering. - Default layout is (N-1, N-2, ... 1, 0) for an N-dimensional tensor. - Default mode is (0, 1, ... N-2, N-1) for an N_dimensional tensor. + Equivalent to:: + + TensorSpec(layout=(N-1, ..., 1, 0), mode=(0, 1, ..., N-1), divisibility=(D0, D1, ... DN-1)) + + This is appropriate for standard row-major (C-contiguous) JAX arrays that + do not require dimension reordering inside the kernel. + + Divisibility hints are inferred only for concrete integer dimensions. + Symbolic dimensions always produce ``None`` for their slot; pass an + explicit ``TensorSpec`` with ``divisibility`` set if you need alignment + hints for symbolic shapes. + + Args: + shaped: An object with a ``.shape`` attribute, or a shape tuple/sequence. + + Returns: + A :class:`TensorSpec` with ``layout`` set to row-major minor-to-major order + and ``mode`` set to the identity permutation. """ if hasattr(shaped, "shape"): shaped = shaped.shape - return TensorSpec(layout=row_major_layout(shaped), mode=default_tensor_mode(shaped)) + inferred = tuple(d if isinstance(d, int) else None for d in shaped) + divisibility = inferred if any(d is not None for d in inferred) else None + return TensorSpec( + layout=row_major_layout(shaped), + mode=default_tensor_mode(shaped), + divisibility=divisibility, + ) + + +def _expand_divisibility( + divisibility, order: tuple[int, ...], ndim: int +) -> tuple[int | None, ...] | None: + """Expand a divisibility spec to a full per-dimension tuple. + + A bare ``int`` is placed at the leading-dimension slot (where + ``order[i] == 0``, i.e. stride == 1) and ``None`` everywhere else. + A tuple is returned unchanged. ``None`` returns ``None``. + """ + if divisibility is None or isinstance(divisibility, tuple): + return divisibility + leading = order.index(0) + result = [None] * ndim + result[leading] = divisibility + return tuple(result) + + +def cutlass_to_jax_layout_order( + layout: Sequence[int] | None, +) -> Sequence[int] | None: + """Converts a CuTeDSL layout order (minor-to-major) to JAX layout order (major-to-minor). + + CuTeDSL uses minor-to-major ordering: ``layout[i]`` is the stride rank of + dimension ``i`` (0 = innermost). JAX uses major-to-minor ordering: position + ``j`` in the result is the dimension index of the ``j``-th outermost axis. + + Example:: + + cutlass_to_jax_layout_order((2, 1, 0)) # row-major → (0, 1, 2) + cutlass_to_jax_layout_order((0, 1, 2)) # col-major → (2, 1, 0) + + Args: + layout: Minor-to-major stride permutation, or ``None`` (returned unchanged). + + Returns: + Major-to-minor axis permutation compatible with ``jax.Array.layout``, or ``None``. + """ + if layout is None: + return None + return tuple(sorted(range(len(layout)), key=lambda i: layout[i], reverse=True)) + + +def jax_to_cutlass_layout_order( + layout: Sequence[int] | None, +) -> Sequence[int] | None: + """Converts a JAX layout order (major-to-minor) to CuTeDSL layout order (minor-to-major). + + JAX uses major-to-minor ordering: position ``j`` is the dimension index of the + ``j``-th outermost axis. CuTeDSL uses minor-to-major ordering: ``layout[i]`` + is the stride rank of dimension ``i`` (0 = innermost). + + This is the inverse of :func:`cutlass_to_jax_layout_order`. + + Example:: + + jax_to_cutlass_layout_order((0, 1, 2)) # row-major → (2, 1, 0) + jax_to_cutlass_layout_order((2, 1, 0)) # col-major → (0, 1, 2) + + Args: + layout: Major-to-minor axis permutation, or ``None`` (returned unchanged). + + Returns: + Minor-to-major stride permutation for use as :attr:`TensorSpec.layout`, or ``None``. + """ + if layout is None: + return None + inv = [0] * len(layout) + for i, p in enumerate(layout): + inv[p] = len(layout) - 1 - i + return tuple(inv) def jax_to_cutlass_dtype(dtype): @@ -144,6 +268,16 @@ def from_dlpack(array, assumed_align: int = DEFAULT_CUTLASS_DEVICE_BUFFER_ALIGNM return _from_dlpack(array, assumed_align=assumed_align) +def _validate_permutation(name: str, perm, shape): + if len(perm) != len(shape): + raise ValueError(f"{name} must be same length as shape", perm, shape) + for s in perm: + if s < 0 or s >= len(shape): + raise ValueError(f"Invalid index {s} in {name}", perm, shape) + if len(set(perm)) != len(perm): + raise ValueError(f"{name} has duplicate indices", perm) + + class JaxArray: """Base class for JaxArray argument type. @@ -172,32 +306,21 @@ class JaxArray: order=None, mode=None, static=False, + divisibility=None, ): self.dtype = dtype self.shape = tuple(shape) self.ndim = len(self.shape) self.mem_space = mem_space self.assumed_align = assumed_align + if order is None: order = row_major_layout(shape) if mode is None: mode = default_tensor_mode(shape) - if len(order) != len(shape): - raise ValueError(f"layout must be same length as shape", order, shape) - for s in order: - if s < 0 or s >= len(shape): - raise ValueError(f"Invalid index {s} in stride order", order, shape) - if len(tuple(set(order))) != len(order): - raise ValueError(f"layout has duplicate indices", order) - - if len(mode) != len(shape): - raise ValueError(f"mode must be same length as shape", mode, shape) - for s in mode: - if s < 0 or s >= len(shape): - raise ValueError(f"Invalid index {s} in stride order", mode, shape) - if len(tuple(set(mode))) != len(mode): - raise ValueError(f"mode has duplicate indices", mode) + _validate_permutation("order", order, shape) + _validate_permutation("mode", mode, shape) self.order = tuple(order) self.mode = tuple(mode) @@ -208,6 +331,20 @@ class JaxArray: ) self.static = static + if divisibility is not None: + divisibility = _expand_divisibility(divisibility, self.order, self.ndim) + divisibility = tuple(divisibility) + if len(divisibility) != len(shape): + raise ValueError( + "divisibility must be same length as shape", divisibility, shape + ) + for d in divisibility: + if not (d is None or isinstance(d, int)): + raise ValueError( + f"divisibility entries must be None or integer, got {d!r}" + ) + self.divisibility = divisibility + class JaxArrayValue(JaxArray): """The IR representation of the JaxArray.""" @@ -222,12 +359,15 @@ class JaxArrayValue(JaxArray): order, mode, static, + divisibility=None, ): - super().__init__(dtype, shape, mem_space, assumed_align, order, mode, static) + super().__init__( + dtype, shape, mem_space, assumed_align, order, mode, static, divisibility + ) self.value = ir_value def __str__(self): - return f"JaxArrayValue<{self.value}:{self.dtype}:{self.shape}:{self.order}:{self.mode}:{self.static}>" + return f"JaxArrayValue<{self.value}:{self.dtype}:{self.shape}:{self.order}:{self.mode}:{self.static}:{self.divisibility}>" def __repr__(self): return str(self) @@ -236,38 +376,44 @@ class JaxArrayValue(JaxArray): self, shape, order: tuple[int, ...], *, loc=None, ip=None ): i32 = ir.IntegerType.get_signless(32) - i64 = ir.IntegerType.get_signless(64) - one = arith.constant(i64, 1) - zero = arith.constant(i64, 0) pairs = sorted(zip(shape, order), key=lambda x: x[1]) # Compute strides for each element in order. strides = [1] # static 1 for leading if len(shape) > 1: strides.append(pairs[0][0]) - for i, idx in enumerate(range(len(pairs[:-2]))): - strides.append(arith.muli(pairs[i][0], strides[-1])) + for i in range(len(pairs) - 2): + strides.append(arith.muli(pairs[i + 1][0], strides[-1])) # Apply the order to strides strides_ordered = [] for i in range(len(shape)): strides_ordered.append(strides[order[i]]) - # zero out any stride for a shape of size 1 to align with make_ordered_layout - # We ignore the leading dimension of 1 - final_stride = [] - for i in range(len(shape)): - x = arith.cmpi(0, one, shape[i]) - s = strides_ordered[i] - if isinstance(s, int) and s == 1: - final_stride.append(s) - else: - final_stride.append(arith.select(x, zero, s)) - # Shapes are expected to be int32 so truncate to that before creating layout - shape = tuple([arith.trunci(i32, s) for s in shape]) + shape_i32 = tuple(arith.trunci(i32, s) for s in shape) - return cute.make_layout(shape, stride=tuple(final_stride)) + # Apply per-mode divisibility assumptions so the compiler can exploit alignment. + if self.divisibility is not None: + assumed = [] + for s32, div_spec, static_s in zip( + shape_i32, self.divisibility, self.shape + ): + if isinstance(static_s, int): + # Pure static shape is known even though a dynamic shape is + # used. We can assume the exact shape here. We keep the shape + # as a dynamic value to avoid breaking code that may expect + # a dynamic value. + assumed.append(cute.assume(s32, divby=static_s)) + elif div_spec is not None: + # Using a dynamic value so apply the div_spec if its provided. + assumed.append(cute.assume(s32, divby=div_spec)) + else: + # No divisibility specification for this shape + assumed.append(s32) + shape_i32 = tuple(assumed) + + return cute.make_layout(shape_i32, stride=tuple(strides_ordered)) def _load_dynamic_shapes(self, ffi_buffer, *, loc=None, ip=None): i64 = ir.IntegerType.get_signless(64) @@ -325,7 +471,9 @@ class JaxArrayValue(JaxArray): layout = cute.make_ordered_layout(shape, order=self.order, loc=loc, ip=ip) else: shape = self._load_dynamic_shapes(ffi_buffer) - layout = self._make_ordered_layout_dynamic_strides(shape, self.order) + layout = self._make_ordered_layout_dynamic_strides( + shape, self.order, loc=loc, ip=ip + ) # Apply mode order if self.mode is not None: @@ -346,6 +494,7 @@ class JaxArrayValue(JaxArray): self.order, self.mode, self.static, + self.divisibility, ) @@ -355,20 +504,8 @@ class JaxTracedArray(JaxArray): Traced values are not real tensors or allocated on the device. """ - def __init__( - self, - dtype, - shape, - mem_space, - assumed_align, - order, - mode, - static, - ): - super().__init__(dtype, shape, mem_space, assumed_align, order, mode, static) - def __str__(self): - return f"JaxTracedArray<{self.dtype}:{self.shape}:{self.order}:{self.mode}:{self.static}>" + return f"JaxTracedArray<{self.dtype}:{self.shape}:{self.order}:{self.mode}:{self.static}:{self.divisibility}>" def __repr__(self): return str(self) @@ -387,6 +524,7 @@ class JaxTracedArray(JaxArray): self.order, self.mode, self.static, + self.divisibility, ) def __c_pointers__(self): diff --git a/python/CuTeDSL/cutlass/pipeline/__init__.py b/python/CuTeDSL/cutlass/pipeline/__init__.py index 1263f3fc2..afb31d462 100644 --- a/python/CuTeDSL/cutlass/pipeline/__init__.py +++ b/python/CuTeDSL/cutlass/pipeline/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/pipeline/helpers.py b/python/CuTeDSL/cutlass/pipeline/helpers.py index ab4c7e57b..495d819b4 100644 --- a/python/CuTeDSL/cutlass/pipeline/helpers.py +++ b/python/CuTeDSL/cutlass/pipeline/helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/pipeline/sm100.py b/python/CuTeDSL/cutlass/pipeline/sm100.py index b804d2a5d..f9c8eb5b2 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm100.py +++ b/python/CuTeDSL/cutlass/pipeline/sm100.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/pipeline/sm90.py b/python/CuTeDSL/cutlass/pipeline/sm90.py index ccbb9ee7b..9a5b3ebb3 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm90.py +++ b/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/torch.py b/python/CuTeDSL/cutlass/torch.py index 754beed35..63757f9e6 100644 --- a/python/CuTeDSL/cutlass/torch.py +++ b/python/CuTeDSL/cutlass/torch.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/__init__.py b/python/CuTeDSL/cutlass/utils/__init__.py index a0909c57a..74783822f 100644 --- a/python/CuTeDSL/cutlass/utils/__init__.py +++ b/python/CuTeDSL/cutlass/utils/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py index c0d146c60..f1c009ea5 100644 --- a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py +++ b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/blockscaled_layout.py b/python/CuTeDSL/cutlass/utils/blockscaled_layout.py index b3461416f..b87c701db 100644 --- a/python/CuTeDSL/cutlass/utils/blockscaled_layout.py +++ b/python/CuTeDSL/cutlass/utils/blockscaled_layout.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/distributed.py b/python/CuTeDSL/cutlass/utils/distributed.py index 019528147..725bf7cb2 100644 --- a/python/CuTeDSL/cutlass/utils/distributed.py +++ b/python/CuTeDSL/cutlass/utils/distributed.py @@ -3,40 +3,66 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. from functools import partial -from typing import Tuple +from typing import Tuple, Union +import cutlass import cutlass.cute as cute -from cutlass.cute.typing import Pointer, Int32 +from cutlass.cute.typing import Pointer, Int32, Literal from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir import ir from cutlass._mlir.dialects import llvm -from cutlass.cute.typing import Literal +from typing_extensions import deprecated __all__ = [ - # misc + # Deprecated "atomicAdd", "ld_bypass", # Message Passing Lock & Unlock "multimem_red_add1", "red_add1", "spin_lock_atom_cas_relaxed_wait", - # Load & Store + "spin_lock_atom_cas_acquire_wait", + "spin_lock_ld_lt_relaxed_wait", + # Dispatch functions + "multimem_ld_reduce", + "multimem_st", + # Load & Store - 128-bit (16 bytes) "multimem_ld_reduce_8xf16", "multimem_ld_reduce_4xf32", "multimem_ld_reduce_8xbf16", "multimem_ld_reduce_16xe4m3", "multimem_ld_reduce_16xe5m2", + # Load & Store - 64-bit (8 bytes) + "multimem_ld_reduce_4xf16", + "multimem_ld_reduce_2xf32", + "multimem_ld_reduce_4xbf16", + "multimem_ld_reduce_8xe4m3", + "multimem_ld_reduce_8xe5m2", + # Load & Store - 32-bit (4 bytes) + "multimem_ld_reduce_2xf16", + "multimem_ld_reduce_1xf32", + "multimem_ld_reduce_2xbf16", + "multimem_ld_reduce_4xe4m3", + "multimem_ld_reduce_4xe5m2", + # Store "multimem_st_4xb32", + "multimem_st_2xb32", + "multimem_st_1xb32", ] +######################################################## +# Deprecated +######################################################## + +@deprecated("atomicAdd is deprecated, use cute.arch.atomic_add instead") @dsl_user_op def atomicAdd(dst_ptr: Pointer, val: Int32, *, loc=None, ip=None) -> Int32: return cute.arch.atomic_add( @@ -49,16 +75,19 @@ def atomicAdd(dst_ptr: Pointer, val: Int32, *, loc=None, ip=None) -> Int32: ) +@deprecated( + "ld_bypass is deprecated, use cute.arch.load with cop='cv' directly instead" +) @cute.jit def ld_bypass(input_tensor: cute.Tensor): fragment = cute.make_rmem_tensor(input_tensor.layout, input_tensor.element_type) - copy_atom_load = cute.make_copy_atom( + copy_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), input_tensor.element_type, memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE, memory_scope=cute.nvgpu.common.MemoryScope.SYS, ) - cute.copy_atom_call(copy_atom_load, input_tensor, fragment) + cute.copy_atom_call(copy_atom, input_tensor, fragment) vals = fragment.load() return vals @@ -78,7 +107,7 @@ def multimem_red_release_gpu_add1( llvm.inline_asm( None, [lock_ptr.toint().ir_value(loc=loc, ip=ip)], - "multimem.red.release.gpu.global.add.u32 [$0], 1;", + "multimem.red.release.gpu.global.add.s32 [$0], 1;", "l", has_side_effects=True, asm_dialect=0, @@ -97,7 +126,7 @@ def multimem_red_release_sys_add1( llvm.inline_asm( None, [lock_ptr.toint().ir_value(loc=loc, ip=ip)], - "multimem.red.release.sys.global.add.u32 [$0], 1;", + "multimem.red.release.sys.global.add.s32 [$0], 1;", "l", has_side_effects=True, asm_dialect=0, @@ -115,7 +144,7 @@ def multimem_red_relaxed_gpu_add1( llvm.inline_asm( None, [lock_ptr.toint().ir_value(loc=loc, ip=ip)], - "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;", + "multimem.red.relaxed.gpu.global.add.s32 [$0], 1;", "l", has_side_effects=True, asm_dialect=0, @@ -133,7 +162,7 @@ def multimem_red_relaxed_sys_add1( llvm.inline_asm( None, [lock_ptr.toint().ir_value(loc=loc, ip=ip)], - "multimem.red.relaxed.sys.global.add.u32 [$0], 1;", + "multimem.red.relaxed.sys.global.add.s32 [$0], 1;", "l", has_side_effects=True, asm_dialect=0, @@ -166,78 +195,6 @@ def multimem_red_add1( multimem_red_relaxed_sys_add1(lock_ptr=lock_ptr, loc=loc, ip=ip) -@dsl_user_op -def red_release_gpu_add1( - lock_ptr: Pointer, - loc=None, - ip=None, -) -> None: - llvm.inline_asm( - None, - [lock_ptr.toint().ir_value(loc=loc, ip=ip)], - "red.release.gpu.global.add.u32 [$0], 1;", - "l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def red_release_sys_add1( - lock_ptr: Pointer, - loc=None, - ip=None, -) -> None: - llvm.inline_asm( - None, - [lock_ptr.toint().ir_value(loc=loc, ip=ip)], - "red.release.sys.global.add.u32 [$0], 1;", - "l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def red_relaxed_gpu_add1( - lock_ptr: Pointer, - loc=None, - ip=None, -) -> None: - llvm.inline_asm( - None, - [lock_ptr.toint().ir_value(loc=loc, ip=ip)], - "red.relaxed.gpu.global.add.u32 [$0], 1;", - "l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - - -@dsl_user_op -def red_relaxed_sys_add1( - lock_ptr: Pointer, - loc=None, - ip=None, -) -> None: - llvm.inline_asm( - None, - [lock_ptr.toint().ir_value(loc=loc, ip=ip)], - "red.relaxed.sys.global.add.u32 [$0], 1;", - "l", - has_side_effects=True, - asm_dialect=0, - loc=loc, - ip=ip, - ) - - @dsl_user_op def red_add1( lock_ptr: Pointer, @@ -248,18 +205,18 @@ def red_add1( ip=None, ) -> None: """ - add 1 to multicast ptr + add 1 to unicast ptr """ - if scope == "gpu": - if order == "release": - red_release_gpu_add1(lock_ptr=lock_ptr, loc=loc, ip=ip) - elif order == "relaxed": - red_relaxed_gpu_add1(lock_ptr=lock_ptr, loc=loc, ip=ip) - elif scope == "sys": - if order == "release": - red_release_sys_add1(lock_ptr=lock_ptr, loc=loc, ip=ip) - elif order == "relaxed": - red_relaxed_sys_add1(lock_ptr=lock_ptr, loc=loc, ip=ip) + cute.arch.red( + lock_ptr.llvm_ptr, + Int32(1), + op="add", + dtype="s32", + sem=order, + scope=scope, + loc=loc, + ip=ip, + ) @cute.jit @@ -287,20 +244,71 @@ def spin_lock_atom_cas_relaxed_wait( ip=ip, ) + +@cute.jit +def spin_lock_atom_cas_acquire_wait( + lock_ptr: Pointer, + *, + expected_val: Int32, + reset_val: Int32, + scope: Literal["gpu", "sys"], + loc=None, + ip=None, +) -> None: + """ + wait on a spin lock until the expected count is reached. Reset flag to reset_val if the expected count is reached. + """ + result = 0 + while result != expected_val: + result = cute.arch.atomic_cas( + ptr=lock_ptr.llvm_ptr, + cmp=Int32(expected_val), + val=Int32(reset_val), + sem="acquire", + scope=scope, + loc=loc, + ip=ip, + ) + + +@cute.jit +def spin_lock_ld_lt_relaxed_wait( + lock_ptr: Pointer, + *, + expected_val: Int32, + scope: Literal["gpu", "sys"], + loc=None, + ip=None, +) -> None: + """ + wait on a spin lock until the expected count is reached. + """ + result = 0 + while result < expected_val: + result = cute.arch.load( + lock_ptr.llvm_ptr, # addr: Pointer to memory location + Int32, # dtype: Data type to load (Int32) + sem="relaxed", + scope=scope, + loc=loc, # loc: Source location for debugging + ip=ip, # ip: Insertion point in IR + ) + + ######################################################## # Multimem Load & Store ######################################################## +# 128-bit (16 bytes) load-reduce base @dsl_user_op -def multimem_ld_reduce_base( +def multimem_ld_reduce_128bit_base( mc_ptr: Pointer, *, ptx_string: str = "", loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32, Int32]: - # ld reduce 8xf16 elts mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) return_struct = llvm.inline_asm( ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"), @@ -316,28 +324,122 @@ def multimem_ld_reduce_base( return return_regs[0], return_regs[1], return_regs[2], return_regs[3] +# 64-bit (8 bytes) load-reduce base +@dsl_user_op +def multimem_ld_reduce_64bit_base( + mc_ptr: Pointer, + *, + ptx_string: str = "", + loc=None, + ip=None, +) -> Tuple[Int32, Int32]: + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) + return_struct = llvm.inline_asm( + ir.Type.parse("!llvm.struct<(i32,i32)>"), + [mc_ptr_int], + ptx_string, + "=r,=r,l", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(2)] + return return_regs[0], return_regs[1] + + +# 32-bit (4 bytes) load-reduce base +@dsl_user_op +def multimem_ld_reduce_32bit_base( + mc_ptr: Pointer, + *, + ptx_string: str = "", + loc=None, + ip=None, +) -> Tuple[Int32]: + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) + return_struct = llvm.inline_asm( + ir.Type.parse("!llvm.struct<(i32)>"), + [mc_ptr_int], + ptx_string, + "=r,l", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(1)] + return (return_regs[0],) + + +# 128-bit variants multimem_ld_reduce_8xf16 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];", + multimem_ld_reduce_128bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];", ) multimem_ld_reduce_4xf32 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];", + multimem_ld_reduce_128bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.v4.f32 {$0, $1, $2, $3}, [$4];", ) multimem_ld_reduce_8xbf16 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];", + multimem_ld_reduce_128bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];", ) multimem_ld_reduce_16xe4m3 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];", + multimem_ld_reduce_128bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];", ) multimem_ld_reduce_16xe5m2 = partial( - multimem_ld_reduce_base, - ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];", + multimem_ld_reduce_128bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];", +) + +# 64-bit variants +multimem_ld_reduce_4xf16 = partial( + multimem_ld_reduce_64bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f32.v2.f16x2 {$0, $1}, [$2];", +) +multimem_ld_reduce_2xf32 = partial( + multimem_ld_reduce_64bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.v2.f32 {$0, $1}, [$2];", +) +multimem_ld_reduce_4xbf16 = partial( + multimem_ld_reduce_64bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f32.v2.bf16x2 {$0, $1}, [$2];", +) +multimem_ld_reduce_8xe4m3 = partial( + multimem_ld_reduce_64bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f16.v2.e4m3x4 {$0, $1}, [$2];", +) +multimem_ld_reduce_8xe5m2 = partial( + multimem_ld_reduce_64bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f16.v2.e5m2x4 {$0, $1}, [$2];", +) + +# 32-bit variants +multimem_ld_reduce_2xf16 = partial( + multimem_ld_reduce_32bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f32.f16x2 {$0}, [$1];", +) +multimem_ld_reduce_1xf32 = partial( + multimem_ld_reduce_32bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.f32 {$0}, [$1];", +) +multimem_ld_reduce_2xbf16 = partial( + multimem_ld_reduce_32bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f32.bf16x2 {$0}, [$1];", +) +multimem_ld_reduce_4xe4m3 = partial( + multimem_ld_reduce_32bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f16.e4m3x4 {$0}, [$1];", +) +multimem_ld_reduce_4xe5m2 = partial( + multimem_ld_reduce_32bit_base, + ptx_string="multimem.ld_reduce.weak.global.add.acc::f16.e5m2x4 {$0}, [$1];", ) +# 128-bit store @dsl_user_op def multimem_st_4xb32( mc_ptr: Pointer, @@ -349,15 +451,147 @@ def multimem_st_4xb32( loc=None, ip=None, ) -> None: - # st 4x32 bits of data mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) llvm.inline_asm( - T.i32(), + None, [mc_ptr_int, x, y, z, w], - "multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};", - "=r,l,r,r,r,r", + "multimem.st.weak.global.v4.f32 [$0], {$1, $2, $3, $4};", + "l,r,r,r,r", has_side_effects=True, asm_dialect=0, loc=loc, ip=ip, ) + + +# 64-bit store +@dsl_user_op +def multimem_st_2xb32( + mc_ptr: Pointer, + x: Int32, + y: Int32, + *, + loc=None, + ip=None, +) -> None: + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) + llvm.inline_asm( + None, + [mc_ptr_int, x, y], + "multimem.st.weak.global.v2.f32 [$0], {$1, $2};", + "l,r,r", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + + +# 32-bit store +@dsl_user_op +def multimem_st_1xb32( + mc_ptr: Pointer, + x: Int32, + *, + loc=None, + ip=None, +) -> None: + mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) + llvm.inline_asm( + None, + [mc_ptr_int, x], + "multimem.st.weak.global.f32 [$0], {$1};", + "l,r", + has_side_effects=True, + asm_dialect=0, + loc=loc, + ip=ip, + ) + + +######################################################## +# Dispatch Functions +######################################################## + + +@dsl_user_op +def multimem_ld_reduce( + mc_ptr: Pointer, + *, + dtype, + num_elements: int, + loc=None, + ip=None, +) -> Union[Tuple[Int32, Int32, Int32, Int32], Tuple[Int32, Int32], Tuple[Int32]]: + """ + Dispatch to appropriate multimem_ld_reduce variant based on dtype and num_elements. + + Args: + mc_ptr: Multicast pointer to load from + dtype: Data type (e.g., cutlass.Float16, cutlass.Float32) + num_elements: Number of dtype elements to load (determines vector width) + + Returns: + Tuple of registers (4 for 128-bit, 2 for 64-bit, 1 for 32-bit) + """ + if dtype == cutlass.Float16: + if num_elements == 8: + return multimem_ld_reduce_8xf16(mc_ptr, loc=loc, ip=ip) + elif num_elements == 4: + return multimem_ld_reduce_4xf16(mc_ptr, loc=loc, ip=ip) + elif num_elements == 2: + return multimem_ld_reduce_2xf16(mc_ptr, loc=loc, ip=ip) + elif dtype == cutlass.Float32: + if num_elements == 4: + return multimem_ld_reduce_4xf32(mc_ptr, loc=loc, ip=ip) + elif num_elements == 2: + return multimem_ld_reduce_2xf32(mc_ptr, loc=loc, ip=ip) + elif num_elements == 1: + return multimem_ld_reduce_1xf32(mc_ptr, loc=loc, ip=ip) + elif dtype == cutlass.BFloat16: + if num_elements == 8: + return multimem_ld_reduce_8xbf16(mc_ptr, loc=loc, ip=ip) + elif num_elements == 4: + return multimem_ld_reduce_4xbf16(mc_ptr, loc=loc, ip=ip) + elif num_elements == 2: + return multimem_ld_reduce_2xbf16(mc_ptr, loc=loc, ip=ip) + elif dtype == cutlass.Float8E4M3FN: + if num_elements == 16: + return multimem_ld_reduce_16xe4m3(mc_ptr, loc=loc, ip=ip) + elif num_elements == 8: + return multimem_ld_reduce_8xe4m3(mc_ptr, loc=loc, ip=ip) + elif num_elements == 4: + return multimem_ld_reduce_4xe4m3(mc_ptr, loc=loc, ip=ip) + elif dtype == cutlass.Float8E5M2: + if num_elements == 16: + return multimem_ld_reduce_16xe5m2(mc_ptr, loc=loc, ip=ip) + elif num_elements == 8: + return multimem_ld_reduce_8xe5m2(mc_ptr, loc=loc, ip=ip) + elif num_elements == 4: + return multimem_ld_reduce_4xe5m2(mc_ptr, loc=loc, ip=ip) + raise ValueError(f"Unsupported dtype={dtype}, num_elements={num_elements}") + + +@dsl_user_op +def multimem_st( + mc_ptr: Pointer, + *regs: Int32, + loc=None, + ip=None, +) -> None: + """ + Dispatch to appropriate multimem_st variant based on number of registers. + + Args: + mc_ptr: Multicast pointer to store to + *regs: 1, 2, or 4 Int32 registers to store + """ + num_regs = len(regs) + if num_regs == 4: + multimem_st_4xb32(mc_ptr, regs[0], regs[1], regs[2], regs[3], loc=loc, ip=ip) + elif num_regs == 2: + multimem_st_2xb32(mc_ptr, regs[0], regs[1], loc=loc, ip=ip) + elif num_regs == 1: + multimem_st_1xb32(mc_ptr, regs[0], loc=loc, ip=ip) + else: + raise ValueError(f"Unsupported number of registers: {num_regs}") diff --git a/python/CuTeDSL/cutlass/utils/dynamic_persistent_tile_scheduler.py b/python/CuTeDSL/cutlass/utils/dynamic_persistent_tile_scheduler.py index f18a4ff38..2e043951d 100644 --- a/python/CuTeDSL/cutlass/utils/dynamic_persistent_tile_scheduler.py +++ b/python/CuTeDSL/cutlass/utils/dynamic_persistent_tile_scheduler.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/gemm/__init__.py b/python/CuTeDSL/cutlass/utils/gemm/__init__.py index e6dcbbe10..2357e392f 100644 --- a/python/CuTeDSL/cutlass/utils/gemm/__init__.py +++ b/python/CuTeDSL/cutlass/utils/gemm/__init__.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/gemm/sm100.py b/python/CuTeDSL/cutlass/utils/gemm/sm100.py index b6f151aa4..9f180b10d 100644 --- a/python/CuTeDSL/cutlass/utils/gemm/sm100.py +++ b/python/CuTeDSL/cutlass/utils/gemm/sm100.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/grouped_gemm_persistent_tile_scheduler.py b/python/CuTeDSL/cutlass/utils/grouped_gemm_persistent_tile_scheduler.py index 227eafbe0..9746d304f 100644 --- a/python/CuTeDSL/cutlass/utils/grouped_gemm_persistent_tile_scheduler.py +++ b/python/CuTeDSL/cutlass/utils/grouped_gemm_persistent_tile_scheduler.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/hardware_info.py b/python/CuTeDSL/cutlass/utils/hardware_info.py index 68f05235a..5549d0e3e 100644 --- a/python/CuTeDSL/cutlass/utils/hardware_info.py +++ b/python/CuTeDSL/cutlass/utils/hardware_info.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/hopper_helpers.py b/python/CuTeDSL/cutlass/utils/hopper_helpers.py index 814c63e09..5268de7db 100644 --- a/python/CuTeDSL/cutlass/utils/hopper_helpers.py +++ b/python/CuTeDSL/cutlass/utils/hopper_helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/layout.py b/python/CuTeDSL/cutlass/utils/layout.py index 1ed47719d..53985b7be 100644 --- a/python/CuTeDSL/cutlass/utils/layout.py +++ b/python/CuTeDSL/cutlass/utils/layout.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/print_latex.py b/python/CuTeDSL/cutlass/utils/print_latex.py index c6402edff..6f6b320b7 100644 --- a/python/CuTeDSL/cutlass/utils/print_latex.py +++ b/python/CuTeDSL/cutlass/utils/print_latex.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/smem_allocator.py b/python/CuTeDSL/cutlass/utils/smem_allocator.py index 01afb3801..be86709eb 100644 --- a/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -3,29 +3,29 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. from typing import Optional, Type, Union, overload +from typing_extensions import deprecated import inspect import cutlass.cute as cute from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size -from cutlass.cutlass_dsl import CuTeDSL, Int8, Numeric, NumericMeta, dsl_user_op - - -SMEM_CAPACITY_MAP = { - "sm_120": (100 - 1) * 1024, - "sm_103": (228 - 1) * 1024, - "sm_100": (228 - 1) * 1024, - "sm_90": (228 - 1) * 1024, - "sm_80": (164 - 1) * 1024, - "sm_86": (100 - 1) * 1024, - "sm_89": (100 - 1) * 1024, -} +from cutlass.cute.tensor import _Tensor +from cutlass.cutlass_dsl import ( + SMEM_CAPACITY_MAP, + CuTeDSL, + Boolean, + Int8, + Numeric, + NumericMeta, + dsl_user_op, +) +from cutlass._mlir.dialects import cute as _cute_ir class SmemAllocator: @@ -36,6 +36,7 @@ class SmemAllocator: .. note:: - The base pointer is aligned to 1024 bytes upon initialization. + - SmemAllocator will automatically calculate the usage upon kernel launch. - There is no need to explicitly specify shared memory size in kernel launch. - Currently only supports static layouts. Dynamic layouts are not supported. @@ -63,6 +64,7 @@ class SmemAllocator: # use of struct members struct_ptr.alpha = 1.0 struct_ptr.x = 2 + x_ptr = struct_ptr.x.ptr # Allocate array int8_array = smem.allocate_array(Int8, 10) # 10 bytes @@ -158,8 +160,14 @@ class SmemAllocator: alignment = max(byte_alignment, size_or_type.__alignof__()) base_ptr = self.allocate(size_in_bytes, alignment, loc=loc, ip=ip) return size_or_type(base_ptr) - elif isinstance(size_or_type, NumericMeta): - size_in_bytes = cute.ceil_div(size_or_type.width, 8) + elif isinstance( + size_or_type, + ( + NumericMeta, + ), + ): + element_width = size_or_type.width if size_or_type is not Boolean else 8 + size_in_bytes = cute.ceil_div(element_width, 8) base_ptr = self.allocate(size_in_bytes, byte_alignment, loc=loc, ip=ip) return cute.recast_ptr(base_ptr, dtype=size_or_type, loc=loc, ip=ip) else: @@ -195,7 +203,15 @@ class SmemAllocator: @dsl_user_op def allocate_array( - self, element_type: Type[Numeric], num_elems: int = 1, *, loc=None, ip=None + self, + element_type: Union[ + Type[Numeric], + ], + num_elems: int = 1, + *, + byte_alignment: int = 1, + loc=None, + ip=None, ): """Allocate an array of elements in shared memory. @@ -208,15 +224,17 @@ class SmemAllocator: :raises ValueError: If num_elems is less than 1 :raises TypeError: If element_type is not a Numeric type """ - if num_elems < 1: + if cute.is_static(num_elems) and num_elems < 1: raise ValueError("num_elems must be at least 1") if not isinstance(element_type, NumericMeta): raise TypeError( f"value_ty must be a type of Numeric, but got {element_type}" ) + element_width = element_type.width if element_type is not Boolean else 8 + byte_alignment = max(byte_alignment, element_width // 8) ptr = self.allocate( - element_type.width // 8 * num_elems, element_type.width // 8, loc=loc, ip=ip + element_width * num_elems // 8, byte_alignment, loc=loc, ip=ip ) return cute.recast_ptr(ptr, dtype=element_type, loc=loc, ip=ip) @@ -276,11 +294,12 @@ class SmemAllocator: raise NotImplementedError(f"dynamic layout is not supported: {layout}") # At least align the allocation to the natural alignment given by the element type - if element_type.width // 8 > byte_alignment: - byte_alignment = element_type.width // 8 + element_width = element_type.width if element_type is not Boolean else 8 + if element_width // 8 > byte_alignment: + byte_alignment = element_width // 8 # Relevant only for sub-byte data types: verify that the entire allocation is byte-aligned - cosize_in_bits = cute.cosize(layout, loc=loc, ip=ip) * element_type.width + cosize_in_bits = cute.cosize(layout, loc=loc, ip=ip) * element_width assert isinstance(cosize_in_bits, int) if cosize_in_bits % 8 != 0: raise ValueError("invalid allocation that is not byte-aligned") @@ -288,7 +307,8 @@ class SmemAllocator: num_bytes = cosize_in_bits // 8 ptr = self.allocate(num_bytes, byte_alignment, loc=loc, ip=ip) ptr = cute.recast_ptr(ptr, swizzle, dtype=element_type, loc=loc, ip=ip) - return cute.make_tensor(ptr, layout, loc=loc, ip=ip) + tensor = cute.make_tensor(ptr, layout, loc=loc, ip=ip) + return _Tensor(tensor, dtype=element_type, loc=loc, ip=ip) # Set explicit signature for Sphinx documentation to avoid issues with @dsl_user_op decorator diff --git a/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py b/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py index 7b8365326..73aa23dfb 100644 --- a/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py +++ b/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/tensor_helpers.py b/python/CuTeDSL/cutlass/utils/tensor_helpers.py index d7dd93fbd..e28afb658 100644 --- a/python/CuTeDSL/cutlass/utils/tensor_helpers.py +++ b/python/CuTeDSL/cutlass/utils/tensor_helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/cutlass/utils/tensormap_manager.py b/python/CuTeDSL/cutlass/utils/tensormap_manager.py index bd06d8893..2e3e9f82f 100644 --- a/python/CuTeDSL/cutlass/utils/tensormap_manager.py +++ b/python/CuTeDSL/cutlass/utils/tensormap_manager.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -19,8 +19,6 @@ from cutlass.cutlass_dsl import dsl_user_op import cutlass.cute as cute from cutlass import const_expr -from cutlass.cute.core import AddressSpace as _CuteAddressSpace -from cutlass.cute.core import make_ptr as _cute_make_ptr class TensorMapUpdateMode(Enum): @@ -140,25 +138,11 @@ class TensorMapManager: warp_idx = cute.arch.make_warp_uniform( cute.arch.warp_idx(loc=loc, ip=ip), loc=loc, ip=ip ) - if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): - # Hoist SMEM pointer integer values into warp-uniform registers before - # entering predicated blocks. This avoids predicated R2UR lowering on sm_90a. - uniform_smem_ptrs = tuple( - _cute_make_ptr( - p.dtype, - cute.arch.make_warp_uniform(p.toint(), loc=loc, ip=ip), - mem_space=_CuteAddressSpace.smem, - assumed_align=p.alignment, - ) - for p in tensormap_smem_ptr - ) - else: - uniform_smem_ptrs = tensormap_smem_ptr # updates before touching tensormap in global memory if warp_idx == warp_id: if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): for copy_atom, tensor, smem_ptr in zip( - tma_copy_atom, tensor_gmem, uniform_smem_ptrs + tma_copy_atom, tensor_gmem, tensormap_smem_ptr ): cute.nvgpu.cpasync.update_tma_descriptor( copy_atom, tensor, smem_ptr, loc=loc, ip=ip @@ -170,7 +154,7 @@ class TensorMapManager: cute.arch.sync_warp(loc=loc, ip=ip) # updates to tensormap in global memory if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): - for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, uniform_smem_ptrs): + for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr): cute.nvgpu.cpasync.cp_fence_tma_desc_release( gmem_ptr, smem_ptr, loc=loc, ip=ip ) diff --git a/python/CuTeDSL/cutlass/utils/tmem_allocator.py b/python/CuTeDSL/cutlass/utils/tmem_allocator.py index 37eef5682..f1b9b1654 100644 --- a/python/CuTeDSL/cutlass/utils/tmem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/tmem_allocator.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA diff --git a/python/CuTeDSL/prep_editable_install.py b/python/CuTeDSL/prep_editable_install.py index b9f869c12..ac7d258f9 100644 --- a/python/CuTeDSL/prep_editable_install.py +++ b/python/CuTeDSL/prep_editable_install.py @@ -40,6 +40,27 @@ class CutlassDSLSetupError(Exception): pass +def get_package_spec(requirements_path: Optional[Path] = None) -> str: + """ + Return the pip requirement spec for nvidia-cutlass-dsl from requirements.txt. + + If anything goes wrong (file not found, parse failure, line missing), + return PACKAGE_NAME as a safe default. + """ + try: + req_path = requirements_path or Path(__file__).with_name("requirements.txt") + with open(req_path, "r", encoding="utf-8") as f: + for raw_line in f: + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.lower().startswith(PACKAGE_NAME): + return line.split("#", 1)[0].strip() + except Exception: + pass + return PACKAGE_NAME + + def download_wheel(temp_dir: Path) -> Path: """ Download the nvidia-cutlass-dsl wheel to a temporary directory. @@ -53,7 +74,10 @@ def download_wheel(temp_dir: Path) -> Path: Raises: CutlassDSLSetupError: If download fails or wheel not found """ - logger.info(f"Downloading {PACKAGE_NAME} wheel to {temp_dir}") + # Resolve package spec from requirements, or fall back to PACKAGE_NAME + package_spec = get_package_spec() + + logger.info(f"Downloading {package_spec} wheel to {temp_dir}") try: subprocess.check_call( @@ -63,7 +87,7 @@ def download_wheel(temp_dir: Path) -> Path: "pip", "download", "--no-deps", - PACKAGE_NAME, + package_spec, "--dest", str(temp_dir), ], @@ -79,7 +103,7 @@ def download_wheel(temp_dir: Path) -> Path: raise CutlassDSLSetupError(error_msg) # Find the downloaded wheel file - wheel_pattern = f"{PACKAGE_NAME.replace('-', '_')}-*.whl" + wheel_pattern = f"*.whl" wheel_files = list(temp_dir.glob(wheel_pattern)) if not wheel_files: raise CutlassDSLSetupError( @@ -108,7 +132,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str: # Construct version regex from package name # Wheel filename format: {package_name_with_underscores}-{version}-{python}-{abi}-{platform}.whl package_pattern = PACKAGE_NAME.replace("-", "_") - version_regex = rf"{re.escape(package_pattern)}-([^-]+)-" + version_regex = rf"{re.escape(package_pattern)}-([^-]+)" version_match = re.match(version_regex, wheel_filename) if version_match: @@ -132,10 +156,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str: return dev_version else: - raise CutlassDSLSetupError( - f"Could not parse version from wheel filename: {wheel_filename}" - ) - + return "9.9.9.dev0" def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None: """ diff --git a/python/CuTeDSL/requirements-cu13.txt b/python/CuTeDSL/requirements-cu13.txt index 4fcd9996b..3b3f7463f 100644 --- a/python/CuTeDSL/requirements-cu13.txt +++ b/python/CuTeDSL/requirements-cu13.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements-cu13.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl[cu13]==4.4.2 +nvidia-cutlass-dsl[cu13]==4.5.0 diff --git a/python/CuTeDSL/requirements.txt b/python/CuTeDSL/requirements.txt index 2238c3db3..80d81892e 100644 --- a/python/CuTeDSL/requirements.txt +++ b/python/CuTeDSL/requirements.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl==4.4.2 +nvidia-cutlass-dsl==4.5.0.dev0 diff --git a/python/cutlass_cppgen/__init__.py b/python/cutlass_cppgen/__init__.py index 889b6f453..0cbf25180 100644 --- a/python/cutlass_cppgen/__init__.py +++ b/python/cutlass_cppgen/__init__.py @@ -133,7 +133,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '4.4.2' +this.__version__ = '4.5.0' from cutlass_cppgen.backend import create_memory_pool from cutlass_cppgen.emit.pytorch import pytorch diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index 0d1abbb32..98d2e077c 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -51,7 +51,7 @@ setup_pycute.perform_setup() setup( name='cutlass_cppgen', - version='4.4.2', + version='4.5.0', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_library.py b/python/setup_library.py index 84edebc8c..c88e3320c 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='cutlass_library', - version='4.4.2', + version='4.5.0', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index c4ae36938..7892f866c 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='4.4.2', + version='4.5.0', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 078d866ef..3374d4e64 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -206,11 +206,11 @@ public: if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/conv/device/conv2d_with_absmax_testbed.h b/test/unit/conv/device/conv2d_with_absmax_testbed.h index 7d8b2e3a9..ea4c395bd 100644 --- a/test/unit/conv/device/conv2d_with_absmax_testbed.h +++ b/test/unit/conv/device/conv2d_with_absmax_testbed.h @@ -444,11 +444,11 @@ struct TestbedConv2dWithAbsMax { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h index 63d51e068..e8c3b78f3 100644 --- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -269,11 +269,11 @@ public: if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/conv/device/conv2d_with_reduction_testbed.h b/test/unit/conv/device/conv2d_with_reduction_testbed.h index 64f30d396..bf4579818 100644 --- a/test/unit/conv/device/conv2d_with_reduction_testbed.h +++ b/test/unit/conv/device/conv2d_with_reduction_testbed.h @@ -197,11 +197,11 @@ public: if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/conv/device/conv3d_testbed.h b/test/unit/conv/device/conv3d_testbed.h index 7d8052aaa..b59821577 100644 --- a/test/unit/conv/device/conv3d_testbed.h +++ b/test/unit/conv/device/conv3d_testbed.h @@ -199,11 +199,11 @@ public: if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/test/unit/conv/device/conv3d_with_broadcast_testbed.h index dd0807878..e0947236e 100644 --- a/test/unit/conv/device/conv3d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv3d_with_broadcast_testbed.h @@ -282,11 +282,11 @@ public: if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 35b053b38..98f580370 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -2827,13 +2827,11 @@ struct TestbedImpl { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - printf("failed due to smem_size\n"); - printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp index 9b505949e..6fd4f8997 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -2022,14 +2022,12 @@ struct TestbedImpl { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); - } - - if (properties.sharedMemPerBlockOptin < smem_size) { - printf("failed due to smem_size\n"); - printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); - return false; - } - + } + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/multistage_testbed.h b/test/unit/gemm/device/multistage_testbed.h index 8a1c802b2..a76ad6dab 100644 --- a/test/unit/gemm/device/multistage_testbed.h +++ b/test/unit/gemm/device/multistage_testbed.h @@ -128,11 +128,11 @@ struct MultistageTestbed { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed.h b/test/unit/gemm/device/testbed.h index ae74296d6..6e5091b69 100644 --- a/test/unit/gemm/device/testbed.h +++ b/test/unit/gemm/device/testbed.h @@ -314,11 +314,11 @@ struct Testbed { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_complex.h b/test/unit/gemm/device/testbed_complex.h index 94d7252c9..e23a59f56 100644 --- a/test/unit/gemm/device/testbed_complex.h +++ b/test/unit/gemm/device/testbed_complex.h @@ -130,11 +130,11 @@ struct TestbedComplex : public Testbed { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/test/unit/gemm/device/testbed_gemm_with_broadcast.h index ccea6ae6c..27efddab8 100644 --- a/test/unit/gemm/device/testbed_gemm_with_broadcast.h +++ b/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -403,11 +403,11 @@ struct TestbedGemmWithBroadcast { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_gemm_with_reduction.h b/test/unit/gemm/device/testbed_gemm_with_reduction.h index 110aba0b5..b24a4ba47 100644 --- a/test/unit/gemm/device/testbed_gemm_with_reduction.h +++ b/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -377,11 +377,11 @@ struct TestbedGemmWithReduction { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_interleaved.h b/test/unit/gemm/device/testbed_interleaved.h index 97ef98215..0446c8e1c 100644 --- a/test/unit/gemm/device/testbed_interleaved.h +++ b/test/unit/gemm/device/testbed_interleaved.h @@ -131,11 +131,11 @@ struct InterleavedTestbed { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_planar_complex.h b/test/unit/gemm/device/testbed_planar_complex.h index 786cc3002..434a2f629 100644 --- a/test/unit/gemm/device/testbed_planar_complex.h +++ b/test/unit/gemm/device/testbed_planar_complex.h @@ -139,11 +139,11 @@ public: if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_rank2k_universal.h b/test/unit/gemm/device/testbed_rank2k_universal.h index d9e6b50a1..a0971aa34 100644 --- a/test/unit/gemm/device/testbed_rank2k_universal.h +++ b/test/unit/gemm/device/testbed_rank2k_universal.h @@ -300,10 +300,11 @@ struct TestbedRank2KUniversal { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_rank_k_universal.h b/test/unit/gemm/device/testbed_rank_k_universal.h index 6f7b584fb..44b391fbf 100644 --- a/test/unit/gemm/device/testbed_rank_k_universal.h +++ b/test/unit/gemm/device/testbed_rank_k_universal.h @@ -287,11 +287,11 @@ struct TestbedRank2KUniversal { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_sparse.h b/test/unit/gemm/device/testbed_sparse.h index 4202ce803..f6daee7e7 100644 --- a/test/unit/gemm/device/testbed_sparse.h +++ b/test/unit/gemm/device/testbed_sparse.h @@ -323,11 +323,11 @@ struct SparseTestbed { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_splitk.h b/test/unit/gemm/device/testbed_splitk.h index 25bd24e67..236d4656f 100644 --- a/test/unit/gemm/device/testbed_splitk.h +++ b/test/unit/gemm/device/testbed_splitk.h @@ -87,11 +87,11 @@ struct TestbedSplitK : public Testbed { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_symm_universal.h b/test/unit/gemm/device/testbed_symm_universal.h index 7684652b9..5520fb94e 100644 --- a/test/unit/gemm/device/testbed_symm_universal.h +++ b/test/unit/gemm/device/testbed_symm_universal.h @@ -326,11 +326,11 @@ struct TestbedSymmUniversal { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_trmm_universal.h b/test/unit/gemm/device/testbed_trmm_universal.h index a10294d83..3c0e91bbf 100644 --- a/test/unit/gemm/device/testbed_trmm_universal.h +++ b/test/unit/gemm/device/testbed_trmm_universal.h @@ -363,11 +363,11 @@ struct TestbedTrmmUniversal { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; } diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h index 2ecca17a2..db874556f 100644 --- a/test/unit/gemm/device/testbed_universal.h +++ b/test/unit/gemm/device/testbed_universal.h @@ -298,11 +298,11 @@ struct TestbedUniversal { if (result != cudaSuccess) { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - - if (properties.sharedMemPerBlockOptin < smem_size) { - return false; - } - + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } return true; }