From ca4fdbea708ad940c905359788372b8add9f85e0 Mon Sep 17 00:00:00 2001
From: dePaul Miller
Date: Wed, 26 Feb 2025 09:44:58 -0800
Subject: [PATCH] Blockwise and Groupwise GEMM for Blackwell and Improvements for Hopper (#2139)

- Blockwise and Groupwise GEMM improvements for Hopper.
- Blockwise and Groupwise GEMM for Blackwell.
- Blockwise Grouped GEMM for Hopper.
- Static ScalePromotionInterval for Hopper FP8 GEMMs.

Co-authored-by: dePaul Miller <23461061+depaulmillz@users.noreply.github.com>
---
 ...specialized_gemm_with_blockwise_scaling.cu |   15 +-
 ...specialized_gemm_with_groupwise_scaling.cu |   32 +-
 .../hopper_fp8_commandline.hpp                |   10 +-
 ...zed_grouped_gemm_with_blockwise_scaling.cu |  841 ++++++++++++
 .../CMakeLists.txt                            |   61 +
 .../hopper_fp8_commandline.hpp                |  211 +++
 .../host/gemm_with_groupwise_scaling.h        |  520 ++++++++
 .../81_blackwell_gemm_blockwise.cu            |  585 +++++++++
 .../81_blackwell_gemm_groupwise.cu            |  589 +++++++++
 .../CMakeLists.txt                            |   57 +
 examples/CMakeLists.txt                       |    2 +
 .../detail/sm100_blockwise_scale_layout.hpp   |  189 +++
 .../builders/sm100_blockwise_umma_builder.inl |  304 +++++
 .../collective/builders/sm90_gmma_builder.inl |   35 +-
 .../gemm/collective/collective_builder.hpp    |    1 +
 .../gemm/collective/collective_mma.hpp        |    2 +
 .../gemm/collective/fp8_accumulation.hpp      |   24 +
 ..._mma_warpspecialized_blockwise_scaling.hpp | 1156 +++++++++++++++++
 ...ma_gmma_rs_warpspecialized_mixed_input.hpp |   14 +
 ..._mma_array_tma_gmma_ss_warpspecialized.hpp |   13 +
 ..._array_tma_gmma_ss_warpspecialized_fp8.hpp |   12 +
 ..._warpspecialized_fp8_blockwise_scaling.hpp | 1036 +++++++++++++++
 ..._warpspecialized_fp8_blockwise_scaling.hpp |  104 +-
 include/cutlass/gemm/dispatch_policy.hpp      |   99 +-
 .../cutlass/gemm/kernel/gemm_universal.hpp    |    1 +
 ...gemm_tma_warpspecialized_mma_transform.hpp | 1008 ++++++++++++++
 ..._array_tma_warpspecialized_cooperative.hpp |    7 +-
 ...emm_array_tma_warpspecialized_pingpong.hpp |    3 +
 28 files changed, 6860 insertions(+), 71 deletions(-)
 create mode 100644 examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu
 create mode 100644 examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt
 create mode 100644 examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp
 create mode 100644 examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h
 create mode 100644 examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu
 create mode 100644 examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu
 create mode 100644 examples/81_blackwell_gemm_blockwise/CMakeLists.txt
 create mode 100644 include/cutlass/detail/sm100_blockwise_scale_layout.hpp
 create mode 100644 include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl
 create mode 100644 include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp
 create mode 100644 include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
 create mode 100644 include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp

diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu
index b6dcc178f..021ca31ec 100644
--- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu @@ -398,6 +398,10 @@ void initialize(const Options &options) { blockscale_tensor_A.sync_device(); blockscale_tensor_B.sync_device(); + // Note : This value has to match the KernelSchedule::ScalePromotionInterval + // Else kernel will fail can_implement() check + // Deprecation Notice : We plan to remove this params member in an upcoming release + // Users can safely delete this line from their code, since the default is already 4 mma_promotion_interval = 4; if (options.save_aux) { @@ -662,9 +666,11 @@ int run(Options &options) // Check if output from CUTLASS kernel and reference kernel are equal or not Result result; - result.passed = verify(options); + if (options.verify) { + result.passed = verify(options); - std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + } // if (!result.passed) { // exit(-1); @@ -674,8 +680,9 @@ int run(Options &options) if (options.iterations > 0) { GpuTimer timer; - timer.start(); - for (int iter = 0; iter < options.iterations; ++iter) { + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + if (iter == options.warmup) + timer.start(); CUTLASS_CHECK(gemm.run()); } timer.stop(); diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu index 74aa36141..ccd0941d0 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu @@ -453,6 +453,10 @@ void initialize(const Options &options) { blockscale_tensor_A.sync_device(); blockscale_tensor_B.sync_device(); + // Note : This value has to match the KernelSchedule::ScalePromotionInterval + // Else kernel will fail can_implement() check + // Deprecation Notice : We plan to remove this params member in an upcoming release + // Users can safely delete this line from their code, since the default is already 4 mma_promotion_interval = 4; if (options.save_aux) { @@ -668,14 +672,14 @@ bool verify(const Options &options, const int ScaleMsPerTile tensor_D.sync_host(); bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); - if (false) { - std::cout << "tensor_ref_D.host_view() {" << std::endl - << tensor_ref_D.host_view() << std::endl - << "}" << std::endl; - std::cout << "tensor_D.host_view() {" << std::endl - << tensor_D.host_view() << std::endl - << "}" << std::endl; - } +#if 0 + std::cout << "tensor_ref_D.host_view() {" << std::endl + << tensor_ref_D.host_view() << std::endl + << "}" << std::endl; + std::cout << "tensor_D.host_view() {" << std::endl + << tensor_D.host_view() << std::endl + << "}" << std::endl; +#endif if (IsDFp8 && options.save_amax) { abs_max_D.sync_host(); @@ -729,13 +733,15 @@ int run(Options &options) // Check if output from CUTLASS kernel and reference kernel are equal or not Result result; - result.passed = verify(options, 
ScaleMsPerTile, ScaleNsPerTile); + if (options.verify) { + result.passed = verify(options, ScaleMsPerTile, ScaleNsPerTile); - std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + } - // if (!result.passed) { - // exit(-1); - // } + if (!result.passed) { + exit(-1); + } // Run profiling loop if (options.iterations > 0) diff --git a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp index e8ea5330b..23f05ada0 100644 --- a/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp +++ b/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -34,6 +34,7 @@ template struct Options { bool help = false; + bool verify = true; float alpha = 1.f, beta = 0.f; float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f; @@ -41,6 +42,7 @@ struct Options { bool save_aux = true; bool save_amax = true; int iterations = 1000; + int warmup = 1000; int m = 1024, n = 512, k = 1024, l = 1; RasterOrderOptions raster; int swizzle; @@ -68,7 +70,9 @@ struct Options { cmd.get_cmd_line_argument("device_scale", device_scale, false); cmd.get_cmd_line_argument("save_aux", save_aux, true); cmd.get_cmd_line_argument("save_amax", save_amax, true); + cmd.get_cmd_line_argument("warmup", warmup); cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("verify", verify); char raster_char; cmd.get_cmd_line_argument("raster", raster_char); @@ -89,8 +93,8 @@ struct Options { /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "54_fp8_hopper_warp_specialized_gemm\n\n" - << " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n" + out << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling\n\n" + << " Hopper FP8 GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement\n\n" << " --m= Sets the M extent of the GEMM\n" @@ -113,7 +117,7 @@ struct Options { out << "\n\nExamples:\n\n" - << "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + << "$ " << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; return out; } diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu new file mode 100644 index 000000000..d20bad582 --- /dev/null +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu @@ -0,0 +1,841 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped scale Hopper FP8 Grouped GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + This example demonstrates a grouped scaled FP8 Grouped GEMM using the new CUTLASS 3.0. + APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + which are more efficient than the Ampere tensor core instructions. + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous + copies between thread blocks in a cluster. This example also showcases on-the-fly modification of TMA + descriptors to move between groups/problem_count (represented by groups). + 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + 4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the + CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can + improve performance. 
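+ 5. Groupwise scaling of the FP8 operands: each group carries its own scale tensors for A and B that are applied during accumulation, at a granularity set by ScaleGranularityM and ScaleGranularityN (see GroupScaleConfig below).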
+ Examples: + $ ./examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling \ + --m=2816 --n=3072 --k=16384 --save_aux=false --save_amax=false \ + --raster=h --swizzle=2 --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +// Includes from examples directory +#include "helper.h" +#include "hopper_fp8_commandline.hpp" +#include "reference/host/gemm_with_groupwise_scaling.h" + +using namespace cute; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementBlockScale = float; // Element type for blockscaling during accumulation +using ElementCompute = float; // Element type for epilogue computation + +using TileShape_ = 
Shape<_128,_128,_128>; // This one is just to make the compiler happy with verify()... + +// ScaleGranularity{M,N}: number of {rows in A}/{columns in B} that share the same scaling factor +// Given TileShape = Shape<_128,_128,_128>: +// ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D (the shape of the scaling factor) +// ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling +// ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling +// ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling +template +struct GroupScaleConfig { + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size + using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster + + static constexpr int ScaleGranularityM = ScaleGranularityM_; + static constexpr int ScaleGranularityN = ScaleGranularityN_; + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + static_assert(size<0>(TileShape{}) == ScaleGranularityM * ScaleMsPerTile, + "FP8 scaling granularity must evenly divide tile shape along M."); + static_assert(size<1>(TileShape{}) == ScaleGranularityN * ScaleNsPerTile, + "FP8 scaling granularity must evenly divide tile shape along N."); + + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using FusionOperation = cutlass::epilogue::fusion::LinearCombination; +}; + +using GroupScale1D1DConfig = GroupScaleConfig< 1, 1>; +using GroupScale1D2DConfig = GroupScaleConfig< 1, size<1>(TileShape_{})>; +using GroupScale2D1DConfig = GroupScaleConfig(TileShape_{}), 1>; +using GroupScale2D2DConfig = GroupScaleConfig(TileShape_{}), size<1>(TileShape_{})>; + +template +struct GroupScaleGemm { + using ArchTag = typename ScheduleConfig::ArchTag; + using OperatorClass = typename ScheduleConfig::OperatorClass; + using TileShape = typename ScheduleConfig::TileShape; + using ClusterShape = typename ScheduleConfig::ClusterShape; + using KernelSchedule = typename ScheduleConfig::KernelSchedule; + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; + using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; + using FusionOperation = typename ScheduleConfig::FusionOperation; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloopWithGroupWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopWithGroupWiseScaling, + 
CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GroupScale1D1DGemm = GroupScaleGemm; +using GroupScale1D2DGemm = GroupScaleGemm; +using GroupScale2D1DGemm = GroupScaleGemm; +using GroupScale2D2DGemm = GroupScaleGemm; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename GroupScale1D1DGemm::Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; +using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; + +using StrideA = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideA; +using StrideB = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideB; +using StrideC = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideC; +using StrideD = typename GroupScale1D1DGemm::Gemm::GemmKernel::InternalStrideD; + +static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + +/// Initialization + +cutlass::DeviceAllocation problem_sizes; + +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_blockscale_A; +std::vector offset_blockscale_B; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; + +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation blockscale_block_A; +cutlass::DeviceAllocation blockscale_block_B; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; +cutlass::DeviceAllocation ptr_blockscale_A; +cutlass::DeviceAllocation ptr_blockscale_B; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams>::RasterOrderOptions; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023, + ScopeMin 
scope_min = std::nullopt, ScopeMax scope_max = std::nullopt) { + + double _scope_max, _scope_min; + int bits_input = cutlass::sizeof_bits::value; + if (bits_input == 1) { + _scope_max = 2; + _scope_min = 0; + } else if (bits_input <= 8) { + _scope_max = 2; + _scope_min = -2; + } else if (bits_input == 16) { + _scope_max = 5; + _scope_min = -5; + } else { + _scope_max = 8; + _scope_min = -8; + } + if constexpr (!std::is_same_v) { + _scope_max = scope_max; + } + if constexpr (!std::is_same_v) { + _scope_min = scope_min; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, (Element) _scope_max, (Element) _scope_min, 0); + + return true; +} + +/// Allocates device-side data +template +void allocate(const OptionType &options) { + + using TileShape = typename OptionType::GroupScaleConfig::TileShape; + const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile; + const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile; + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_elements_blockscale_A = 0; + int64_t total_elements_blockscale_B = 0; + + offset_A.clear(); + offset_B.clear(); + offset_C.clear(); + offset_D.clear(); + offset_blockscale_A.clear(); + offset_blockscale_B.clear(); + stride_A_host.clear(); + stride_B_host.clear(); + stride_C_host.clear(); + stride_D_host.clear(); + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{}))); + auto groupscale_m = cute::get<0>(blockscale_shape) * ScaleMsPerTile; // We need to pad along M in scale tensor of A to prevent illegal memory access. + auto groupscale_n = cute::get<1>(blockscale_shape) * ScaleNsPerTile; // We need to pad along N in scale tensor of A to prevent illegal memory access. 
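+ // For example, with the 128x128x128 tile and ScaleGranularityM = 1 (ScaleMsPerTile = 128), a group with M = 256 and K = 512 stores 256 x 4 scale values for A: one per row of A for each 128-deep slice of K.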
+ auto blockscale_k = cute::get<2>(blockscale_shape); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_blockscale_A.push_back(total_elements_blockscale_A); + offset_blockscale_B.push_back(total_elements_blockscale_B); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + int64_t elements_blockscale_A = groupscale_m * blockscale_k; + int64_t elements_blockscale_B = groupscale_n * blockscale_k; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_elements_blockscale_A += elements_blockscale_A; + total_elements_blockscale_B += elements_blockscale_B; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + blockscale_block_A.reset(total_elements_blockscale_A); + blockscale_block_B.reset(total_elements_blockscale_B); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +template +void initialize(const OptionType &options) { + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + std::vector ptr_blockscale_A_host(options.groups); + std::vector ptr_blockscale_B_host(options.groups); + + alpha_host.clear(); + beta_host.clear(); + + for (int i = 0; i < options.groups; i++) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + ptr_blockscale_A_host.at(i) = blockscale_block_A.get() + offset_blockscale_A.at(i); + ptr_blockscale_B_host.at(i) = blockscale_block_B.get() + offset_blockscale_B.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? 
static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + ptr_blockscale_A.reset(options.groups); + ptr_blockscale_A.copy_from_host(ptr_blockscale_A_host.data()); + + ptr_blockscale_B.reset(options.groups); + ptr_blockscale_B.copy_from_host(ptr_blockscale_B_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2022); + initialize_block(block_B, seed + 2023); + initialize_block(block_C, seed + 2024); + initialize_block(blockscale_block_A, seed + 2025, -1, 1); + initialize_block(blockscale_block_B, seed + 2026, -1, 1); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +GemmArguments args_from_options(const OptionType &options, bool host_problem_shapes_available = true) +{ + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + int device_id = 0; + cutlass::KernelHardwareInfo kernel_hw_info = cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); + + GemmArguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), host_problem_shapes_available ? options.problem_sizes_host.data() : (decltype(options.problem_sizes_host.data())) nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), + ptr_blockscale_A.get(), + ptr_blockscale_B.get() + }, + { + {}, // epilogue.thread + ptr_C.get(), stride_C.get(), + ptr_D.get(), stride_D.get() + }, + kernel_hw_info + }; + + auto &fusion_args = arguments.epilogue.thread; + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. 
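+ // The scalar alpha/beta fields below are left at zero and ignored; each group instead reads its own value from the device-side arrays, advancing by one entry per group (stride 1 in dAlpha/dBeta).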
+ fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + + arguments.scheduler.raster_order = options.raster; + // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + + return arguments; +} + +template +bool verify(const OptionType &options) { + + // + // Compute reference output + // + + std::vector block_A_host(block_A.size()); + std::vector block_B_host(block_B.size()); + std::vector block_C_host(block_C.size()); + std::vector block_D_host_kernel(block_D.size()); + std::vector block_D_host_ref(block_D.size()); + std::vector blockscale_block_A_host(blockscale_block_A.size()); + std::vector blockscale_block_B_host(blockscale_block_B.size()); + + block_A.copy_to_host(block_A_host.data()); + block_B.copy_to_host(block_B_host.data()); + block_C.copy_to_host(block_C_host.data()); + block_D.copy_to_host(block_D_host_kernel.data()); + blockscale_block_A.copy_to_host(blockscale_block_A_host.data()); + blockscale_block_B.copy_to_host(blockscale_block_B_host.data()); + + bool passed = true; + for (int group_idx = 0; group_idx < options.groups; group_idx++) { + // Group scaling tensors shapes based `ScaleGranularityM`, CTA Block (TileShape) and GEMM Problem shape + auto [m, n, k] = options.problem_sizes_host.at(group_idx); + auto gemm_problem_shape = cute::make_shape(m, n, k); + auto blockscale_shape = shape(get<1>(cute::zipped_divide(cute::make_layout(gemm_problem_shape), TileShape_{}))); + auto blockscale_m = cute::get<0>(blockscale_shape); + auto blockscale_n = cute::get<1>(blockscale_shape); + auto blockscale_k = cute::get<2>(blockscale_shape); + auto groupscale_m = blockscale_m * OptionType::GroupScaleConfig::ScaleMsPerTile; + auto groupscale_n = blockscale_n * OptionType::GroupScaleConfig::ScaleNsPerTile; + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(block_A_host.data() + offset_A.at(group_idx), + cute::make_layout( + cute::make_shape(m, k, 1), + stride_A_host.at(group_idx) + ) + ); + auto B = cute::make_tensor(block_B_host.data() + offset_B.at(group_idx), + cute::make_layout( + cute::make_shape(n, k, 1), + stride_B_host.at(group_idx) + ) + ); + auto C = cute::make_tensor(block_C_host.data() + offset_C.at(group_idx), + cute::make_layout( + cute::make_shape(m, n, 1), + stride_C_host.at(group_idx) + ) + ); + auto D = cute::make_tensor(block_D_host_ref.data() + offset_D.at(group_idx), + cute::make_layout( + cute::make_shape(m, n, 1), + stride_D_host.at(group_idx) + ) + ); + + auto blockscale_A = cute::make_tensor(blockscale_block_A_host.data() + offset_blockscale_A.at(group_idx), + cute::make_layout( + cute::make_shape(groupscale_m, blockscale_k, 1), + cute::make_stride(1, groupscale_m, groupscale_m * blockscale_k) + ) + ); + auto blockscale_B = cute::make_tensor(blockscale_block_B_host.data() + offset_blockscale_B.at(group_idx), + cute::make_layout( + cute::make_shape(groupscale_n, blockscale_k, 1), + cute::make_stride(1, groupscale_n, groupscale_n * blockscale_k) + ) + ); + + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams< + ElementAccumulator, + decltype(A), + decltype(B), + decltype(blockscale_A), + 
decltype(blockscale_B), + TileShape_ + > mainloop_params{ + A, B, // Operand Tensors + blockscale_A, blockscale_B // Groupwise scaling Tensors + }; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + unused_t, // Aux + unused_t, // valpha + unused_t, // vbeta + ActivationFunctor + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha_host.at(group_idx); + epilogue_params.beta = beta_host.at(group_idx); + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + auto this_group_passed = std::equal( + // std::execution::par_unseq, + block_D_host_ref.data() + offset_D.at(group_idx), + block_D_host_ref.data() + offset_D.at(group_idx) + m * n, + block_D_host_kernel.data() + offset_D.at(group_idx) + ); + + passed &= this_group_passed; + +#if 0 + std::cout << "Group: " << group_idx << " M: " << m << " N: " << n << " K: " << k << " Status: " << this_group_passed << std::endl; +#endif + + } + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(OptionType &options, bool host_problem_shapes_available = true) +{ + using TileShape = typename OptionType::GroupScaleConfig::TileShape; + const int ScaleGranularityM = OptionType::GroupScaleConfig::ScaleGranularityM; + const int ScaleGranularityN = OptionType::GroupScaleConfig::ScaleGranularityN; + const int ScaleMsPerTile = OptionType::GroupScaleConfig::ScaleMsPerTile; + const int ScaleNsPerTile = OptionType::GroupScaleConfig::ScaleNsPerTile; + + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
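+ // Note that gemm.initialize() is called inside the timed loop above, so the reported average runtime also includes per-iteration host-side argument setup.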
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::string raster = "Heuristic"; + + if (options.raster == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Tile shape (M, N, K): " << size<0>(TileShape{}) << ", " << size<1>(TileShape{}) << ", " << size<2>(TileShape{}) << std::endl; + std::cout << " ScaleGranularityM: " << ScaleGranularityM << " (ScaleMsPerTile: " << ScaleMsPerTile << ")" << std::endl; + std::cout << " ScaleGranularityN: " << ScaleGranularityN << " (ScaleNsPerTile: " << ScaleNsPerTile << ")" << std::endl; + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + fflush(stdout); + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0;
+ }
+
+ cudaDeviceProp props;
+ int current_device_id;
+ CUDA_CHECK(cudaGetDevice(&current_device_id));
+ CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+ cudaError_t error = cudaGetDeviceProperties(&props, 0);
+ if (props.major != 9) {
+ std::cerr
+ << "This example requires a GPU of NVIDIA's Hopper Architecture or "
+ << "later (compute capability 90 or greater).\n";
+ return 0;
+ }
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+ //
+ // Parse options
+ //
+
+ Options options_1d1d;
+ Options options_1d2d;
+ Options options_2d1d;
+ Options options_2d2d;
+
+ options_1d1d.parse(argc, args);
+ options_1d2d.parse(argc, args);
+ options_2d1d.parse(argc, args);
+ options_2d2d.parse(argc, args);
+
+ if (options_1d1d.help) {
+ options_1d1d.print_usage(std::cout) << std::endl;
+ return 0;
+ }
+
+ //
+ // Evaluate CUTLASS kernels
+ //
+
+ auto run_tests = [&] (bool host_problem_shapes_available = true) {
+ std::cout << "Grouped GEMM kernel with 1D1D group scale" << std::endl;
+ run(options_1d1d, host_problem_shapes_available);
+ std::cout << "Grouped GEMM kernel with 1D2D group scale" << std::endl;
+ run(options_1d2d, host_problem_shapes_available);
+ std::cout << "Grouped GEMM kernel with 2D1D group scale" << std::endl;
+ run(options_2d1d, host_problem_shapes_available);
+ std::cout << "Grouped GEMM kernel with 2D2D group scale" << std::endl;
+ run(options_2d2d, host_problem_shapes_available);
+ std::cout << std::endl;
+ };
+
+ std::cout << "Running tests with host problem shapes:" << std::endl;
+ run_tests(true);
+ std::cout << "Running tests without host problem shapes:" << std::endl;
+ run_tests(false);
+
+#endif
+
+ return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt
new file mode 100644
index 000000000..f88b31674
--- /dev/null
+++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/CMakeLists.txt
@@ -0,0 +1,61 @@
+# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note that we set --iterations=0 for all tests below to disable the performance benchmarking. +# Only the correctness check will be run by these commands. + +set(TEST_RANDOM --iterations=0) # Random problem sizes +set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=0) # Random problem sizes + +set(TEST_FIXED --m=2048 --n=5120 --k=512 --groups=50 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes + +set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes +set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes + +cutlass_example_add_executable( + 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling + 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + ) diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp new file mode 100644 index 000000000..3e425fe23 --- /dev/null +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -0,0 +1,211 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +template +struct Options { + + using RasterOrderOptions = _RasterOrderOptions; + using ProblemShape = _ProblemShape; + using GroupScaleConfig = _GroupScaleConfig; + + bool help = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, groups = 10; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + int const k_alignment = 128; + int const m_alignment = 128; + int const n_alignment = 128; + + RasterOrderOptions raster; + int swizzle; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 1) { + m = m_alignment * ((rand() % (64 * alignment / m_alignment)) + 1); + } + if (n < 1) { + n = n_alignment * ((rand() % (64 * alignment / n_alignment)) + 1); + } + if (k < 1) { + k = k_alignment * ((rand() % (32 * alignment / k_alignment)) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if 
(!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problem_sizes_host.size()); + + return true; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\n\n" + << " Hopper FP8 Grouped GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --benchmark= Executes a benchmark problem size.\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Number of real-valued multiply-adds + uint64_t fmas = 0ull; + + for (auto const [m, n, k] : problem_sizes_host) { + fmas += static_cast(m) * + static_cast(n) * + static_cast(k); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h new file mode 100644 index 000000000..1a94af670 --- /dev/null +++ b/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h @@ -0,0 +1,520 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GETT in host-side code. +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/gemm.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/relatively_equal.h" +#include +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +template +struct ElementTraits { + using type = T; +}; + +template +struct ElementTraits().get()), void> > > { + using type = decltype(std::declval().get()); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorB_, // (N, K, L) + class TensorScaleA_, // (m, k, L) + class TensorScaleB_, // (n, k, L) + class TileShape_ +> +struct GettMainloopParams { + using ElementAccumulator = ElementAccumulator_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using EngineA = typename TensorA::engine_type; + using LayoutA = typename TensorA::layout_type; + using EngineB = typename TensorB::engine_type; + using LayoutB = typename TensorB::layout_type; + + using TensorScaleA = TensorScaleA_; + using TensorScaleB = TensorScaleB_; + using TileShape = TileShape_; + using EngineScaleA = typename TensorScaleA::engine_type; + using EngineScaleB = typename TensorScaleB::engine_type; + + TensorA A{}; + TensorB B{}; + TensorScaleA ScaleA{}; + TensorScaleB ScaleB{}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template< + class ElementScalar_, + class ElementScalingFactor_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, // (M, N, L) + class TensorD_, // (M, N, L) + class VectorBias_ = TensorD_, // (M, 1) + class TensorAux_ = TensorD_, // (M, N, L) + class VectorAlpha_ = TensorD_, // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) + class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class BiasBinaryOp_ = cutlass::plus, + bool PerColumnBias_ = false +> +struct GettEpilogueParams { + using ElementScalar = ElementScalar_; + using ElementScalingFactor = ElementScalingFactor_; + using ElementAccumulator = ElementAccumulator_; + using 
ElementCompute = ElementCompute_; + using TensorC = TensorC_; + using TensorD = TensorD_; + using TensorAux = TensorAux_; + using VectorBias = VectorBias_; + using VectorAlpha = VectorAlpha_; + using VectorBeta = VectorBeta_; + using ActivationFunctor = ActivationFunctor_; + using BiasBinaryOp = BiasBinaryOp_; + + using EngineC = typename TensorC::engine_type; + using LayoutC = typename TensorC::layout_type; + using EngineD = typename TensorD::engine_type; + using LayoutD = typename TensorD::layout_type; + static constexpr bool PerColumnBias = PerColumnBias_; + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorC C{}; + TensorD D{}; + VectorBias Bias{}; + TensorAux Aux{}; + VectorAlpha Valpha{}; + VectorBeta Vbeta{}; + ElementCompute st = ElementCompute(1); + + ElementAccumulator* abs_max_D = nullptr; + ElementAccumulator* abs_max_Aux = nullptr; + + ElementScalingFactor scale_a = ElementScalingFactor(1); + ElementScalingFactor scale_b = ElementScalingFactor(1); + ElementScalingFactor scale_c = ElementScalingFactor(1); + ElementScalingFactor scale_d = ElementScalingFactor(1); + ElementScalingFactor scale_aux = ElementScalingFactor(1); + + bool beta_per_channel_scaling = false; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - General Tensor-Tensor contraction reference kernel with Groupwise scaling +template < + class MainloopParams, + class EpilogueParams +> +void Gett( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + + static int constexpr kBlockM = cute::get<0>(typename MainloopParams::TileShape{}); + static int constexpr kBlockN = cute::get<1>(typename MainloopParams::TileShape{}); + // printf("mainloop_params.ScaleA.layout()"); cute::print(mainloop_params.ScaleA.layout()); printf("\n"); + // printf("mainloop_params.ScaleB.layout()"); cute::print(mainloop_params.ScaleB.layout()); printf("\n"); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + gett_epilogue(epilogue_params, m, n, l, acc); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Mainloop +template +void gett_mainloop( + MainloopParams const& mainloop_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); + static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementA = typename ElementTraits::type; + using ElementB = typename ElementTraits::type; + using ElementBlockScaleA = typename ElementTraits::type; + using ElementBlockScaleB = typename ElementTraits::type; + + using RingOp = multiply_add; + RingOp fma_op; + + multiplies scale_op; + + static int constexpr kBlockK = cute::get<2>(typename MainloopParams::TileShape{});; + + // Tempo accumulators to seperate blockwise accumulation + typename MainloopParams::ElementAccumulator acc_temp[kBlockM][kBlockN]; + + // Zero out accumulators + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int 
n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + acc_temp[m_b][n_b] = ElementAccumulator(0); + } + } + + const int M = cute::size<0>(mainloop_params.A.layout()); + const int N = cute::size<0>(mainloop_params.B.layout()); + + const int ScaleGranularityM = M / cute::size<0>(mainloop_params.ScaleA.layout()); + const int ScaleGranularityN = N / cute::size<0>(mainloop_params.ScaleB.layout()); + + assert(ScaleGranularityM && M % ScaleGranularityM == 0 && "ScaleGranularityM must divide M"); + assert(ScaleGranularityN && N % ScaleGranularityN == 0 && "ScaleGranularityN must divide N"); + + cute::Tensor blockscale_A = domain_offset(make_coord(m / ScaleGranularityM, _0{}), mainloop_params.ScaleA(_, _, l)); + cute::Tensor blockscale_B = domain_offset(make_coord(n / ScaleGranularityN, _0{}), mainloop_params.ScaleB(_, _, l)); + + // Compute on this k-block + for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { + + // Load Blockwise scaling factors from blockscale Tensors for A and B + int64_t block_k = k / kBlockK; + cute::Tensor scale_a = blockscale_A(_, block_k); + cute::Tensor scale_b = blockscale_B(_, block_k); + + // Load A + ElementAccumulator a_frag[kBlockM]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + } else { + a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Load B + ElementAccumulator b_frag[kBlockN]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast B value to accumulator type.
+ b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + } else { + b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // do compute + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc_temp[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc_temp[m_b][n_b]); + } + } + + // Apply Groupwise-scaling at kBlockK boundary + // (a) Apply group and block scaling factors on the partial accumulated results (acc_temp) at the kBlocK boundary + // (b) Zero-out partial temporary (acc_temp), + // (c) Update permanent (accu) + if ((k+1) % kBlockK == 0) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + auto scale_a_m_b = scale_a[m_b / ScaleGranularityM]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + auto scale_b_n_b = scale_b[n_b / ScaleGranularityN]; + ElementAccumulator blockwise_scaled_accum = acc_temp[m_b][n_b] * scale_a_m_b * scale_b_n_b; + acc[m_b][n_b] = blockwise_scaled_accum + acc[m_b][n_b]; + acc_temp[m_b][n_b] = ElementAccumulator(0); + } + } + } + + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Epilogue +template +void gett_epilogue( + EpilogueParams const& epilogue_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); + static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementCompute = typename EpilogueParams::ElementCompute; + using ElementC = typename EpilogueParams::TensorC::value_type; + using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementAux = typename EpilogueParams::TensorAux::value_type; + using ElementBias = typename EpilogueParams::VectorBias::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; + using ActivationFunctor = typename EpilogueParams::ActivationFunctor; + using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + + constexpr bool PerColBias = EpilogueParams::PerColumnBias; + constexpr bool IsScalingAndAmaxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsScalingAndAmaxAuxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsReLUAuxNeeded = + (cute::is_same_v> or + cute::is_same_v>) and + cute::is_same_v; + constexpr bool IsClamp = + cute::is_same_v>; + + constexpr bool IsBackpropFusion = + cute::is_same_v> or + cute::is_same_v>; + + // Input related converter + NumericConverter accumulator_converter; + NumericConverter source_converter; + NumericConverter bias_converter; + [[maybe_unused]] NumericConverter aux_source_converter; + + // Scale related converter + NumericConverter scale_converter; + NumericConverter scaling_factor_converter; + + // Abs max converter + [[maybe_unused]] NumericConverter abs_max_output_converter; + + // Output related converter + NumericConverter destination_converter; + [[maybe_unused]] NumericConverter aux_destination_converter; + NumericConverter dBias_converter; + + // Epilogue operations + multiply_add epilogue_fma; + multiplies mul; + plus add; + + // Activation operation + + auto activation = [] (ElementCompute x, ElementCompute y = ElementCompute(0)) { + if constexpr (std::is_same_v) { + return x + y; + } else { + return ActivationFunctor()(x, y); + } + }; + + // Bias binary operation + BiasBinaryOp bias_op; + + // 
Do conversion + ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); + ElementCompute converted_beta = scale_converter(epilogue_params.beta); + ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); + ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); + ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); + ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); + ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); + + // Init local var + [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); + [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); + + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + converted_beta = mul(converted_beta, converted_scale_c); + + ElementCompute inter_accum[kBlockM][kBlockN]; + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + ElementCompute local_dBias = ElementCompute(0); + + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + // Convert every type to ElementCompute first, do compute, convert to output type, write it out + ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); + // per-row alpha + if (raw_pointer_cast(epilogue_params.Valpha.data())) { + converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); + } + ElementCompute output = mul(converted_alpha, converted_acc); + + if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { + ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b)); + output = bias_op(output, converted_bias); + } + + if (raw_pointer_cast(epilogue_params.C.data())) { + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + // per-row beta + if (epilogue_params.Vbeta.data()) { + converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); + } + output = epilogue_fma(converted_beta, converted_src, output); + } + + if constexpr (IsBackpropFusion) { + ElementAux aux_input = ElementAux(0); + if (raw_pointer_cast(epilogue_params.Aux.data())) { + aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); + } + + output = activation(output, aux_source_converter(aux_input)); + local_dBias = add(local_dBias, output); + } + else { + if (raw_pointer_cast(epilogue_params.Aux.data())) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + if constexpr (IsReLUAuxNeeded) { + epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? 
uint1b_t(1) : uint1b_t(0); + } else { + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + } + } + + if constexpr (IsClamp) { // Treat Clamp as ReLU + output = activation(output, {0, std::numeric_limits::max()}); + } + else { + output = activation(output); + } + } + + if constexpr (IsScalingAndAmaxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_output = amax_op(local_abs_max_output, output); + output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); + } + + inter_accum[m_b][n_b] = ElementCompute(output); + } + } // n_b + + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { + ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); + local_dBias = add(local_dBias, converted_dBias); + epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); + } + } + } // m_b + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); + } + } + } + +#if defined(_OPENMP) + #pragma omp critical(Abs_Max_Data_Update) +#endif + { + if constexpr (IsScalingAndAmaxOutputNeeded) { + if (epilogue_params.abs_max_D) { + *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); + } + } + + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + if (epilogue_params.abs_max_Aux) { + *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM - General Matrix-Matrix contraction without conjugation options +template < + class MainloopParams, + class EpilogueParams +> +void Gemm3x( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + using namespace cute; + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "Only Rank3 Tensors (M, K, Batch_Count) " + "with Batchmode are supported"); + // Lower the Matrix-Multiplication with Groupwise scaling (Gemm3x) to a Tensor Contraction (Gett). 
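+  // The reference tensors are rank-3: A is (M, K, L), B is (N, K, L), C/D are (M, N, L), and the
+  // scale tensors ScaleA/ScaleB are (m, k, L) and (n, k, L) with one entry per scale group.
+  // Gett() tiles the output into kBlockM x kBlockN blocks taken from TileShape and, for each block,
+  // runs gett_mainloop (groupwise-scaled accumulation) followed by gett_epilogue.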
+ Gett(mainloop_params, epilogue_params); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu new file mode 100644 index 000000000..417830f20 --- /dev/null +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_blockwise.cu @@ -0,0 +1,585 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A FP8 blockwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. 
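+
+    Both A and B are FP8 (e4m3) operands accumulated in FP32. Each operand additionally carries
+    FP32 scale factors: this example uses the trivial SM100 blockwise scale configuration, i.e.
+    one scale factor per MMA tile of A and one per MMA tile of B. The scale-factor layouts
+    LayoutSFA/LayoutSFB are deduced from ScaleConfig and tiled to the runtime problem shape.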
+*/ + + + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; // Element Accumulator will also be our scale factor type +using ElementCompute = float; + + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = Shape<_128,_128,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_1,_1,_1>; + +using ScaleConfig = decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{})); + +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutC, AlignmentD, + 
cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +// Strides just iterate over scalars and have no zeros +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; +// Layouts are tiled to the problem size and the strides have zeros +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_SFA; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_SFB; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "81_blackwell_gemm_blockwise\n\n" + << " Blackwell FP8 GEMM with Blockwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "81_blackwell_gemm_blockwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Helper to initialize a block of device data (scale_tensors) +template +bool initialize_scale_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + + scope_min = -8; + scope_max = 8; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + 
cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + + auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA))); + auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB))); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + tensor_SFA.resize(blockscale_a_coord); + tensor_SFB.resize(blockscale_b_coord); + + initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024); + + initialize_scale_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025); + initialize_scale_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, + tensor_B.device_data(), stride_B, + tensor_SFA.device_data(), layout_SFA, + tensor_SFB.device_data(), layout_SFB}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + 
cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA); + auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + Result result; + if (!options.skip_verification) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
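+      // avg_runtime_ms = elapsed_ms / iterations, and
+      // GFLOP/s = (2 * M * N * K) / 1e9 / avg_runtime_s  (see Options::gflops above)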
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least sm100a. + + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || props.minor != 0) { + std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu new file mode 100644 index 000000000..11083e098 --- /dev/null +++ b/examples/81_blackwell_gemm_blockwise/81_blackwell_gemm_groupwise.cu @@ -0,0 +1,589 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A FP8 groupwise scaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTLASS. +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e4m3_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// MMA type +using ElementAccumulator = float; // Element Accumulator will also be our scale factor type +using ElementCompute = float; + + +// MMA and Cluster Tile Shapes +// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster Shape %2 == 0 +using MmaTileShape_MNK = 
Shape<_128,_128,_128>; +// Shape of the threadblocks in a cluster +using ClusterShape_MNK = Shape<_1,_1,_1>; + +constexpr int ScaleGranularityM = 1; +constexpr int ScaleGranularityN = 128; +constexpr int ScaleGranularityK = 128; +using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig; + +// Note when we have multiple scale factors per tile (in this case 128 scales in M per tile), we will restrict up to a +// 16B alignment if possible (i.e., we have at least 16B of scales in M). +// In this case the smallest M that can be executed is 16. To avoid this for smaller M, you can swap A and B +// and transpose A, B, C, and scales since B^T A^T = C^T. +using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand +using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutC, AlignmentD, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + ElementA, cute::tuple, AlignmentA, + ElementB, cute::tuple, AlignmentB, + ElementAccumulator, + MmaTileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 // Note: Groupwise and Blockwise only support 1 SM MMA at this moment + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +// Strides just iterate over scalars and have no zeros +LayoutSFA layout_SFA; +LayoutSFB layout_SFB; +// Layouts are tiled to the problem size and the strides have zeros +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_SFA; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_SFB; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + bool skip_verification = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("skip-verification")) { + skip_verification = true; + } + + 
cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "81_blackwell_gemm_groupwise\n\n" + << " Blackwell FP8 GEMM with Groupwise Scaling using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --skip-verification Skip verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "81_blackwell_gemm_groupwise" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Helper to initialize a block of device data (scale_tensors) +template +bool initialize_scale_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind 
dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + + scope_min = -8; + scope_max = 8; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + using namespace cute; + + auto gemm_problem_shape = cute::make_shape(options.m, options.n, options.k); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(make_shape(options.m, options.n, options.k, options.l)); + layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(make_shape(options.m, options.n, options.k, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto blockscale_a_coord = cutlass::make_Coord(size(filter_zeros(layout_SFA))); + auto blockscale_b_coord = cutlass::make_Coord(size(filter_zeros(layout_SFB))); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + tensor_SFA.resize(blockscale_a_coord); + tensor_SFB.resize(blockscale_b_coord); + + initialize_tensor(tensor_A.host_view(), cutlass::Distribution::Uniform, seed + 2022); + initialize_tensor(tensor_B.host_view(), cutlass::Distribution::Uniform, seed + 2023); + initialize_tensor(tensor_C.host_view(), cutlass::Distribution::Uniform, seed + 2024); + + initialize_scale_tensor(tensor_SFA.host_view(), cutlass::Distribution::Uniform, seed + 2025); + initialize_scale_tensor(tensor_SFB.host_view(), cutlass::Distribution::Uniform, seed + 2026); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, + tensor_B.device_data(), stride_B, + tensor_SFA.device_data(), layout_SFA, + tensor_SFB.device_data(), layout_SFB}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + 
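// The default epilogue computes D = alpha * accumulator + beta * C, with alpha and beta taken from the command line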
fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto SFA = cute::make_tensor(tensor_SFA.host_data(), layout_SFA); + auto SFB = cute::make_tensor(tensor_SFB.host_data(), layout_SFB); + + using unused_t = decltype(D); + + cutlass::reference::host::GettBlockScalingMainloopParams< + ElementAccumulator, + decltype(A), + decltype(SFA), + decltype(B), + decltype(SFB) + > mainloop_params{A, SFA, B, SFB}; + + cutlass::reference::host::GettEpilogueParams< + ElementAccumulator, + ElementAccumulator, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D) + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + Result result; + if (!options.skip_verification) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least sm100a. + + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major != 10 || props.minor != 0) { + std::cerr << "This example requires a GPU with compute capability 100a)." << std::endl; + return 0; + } + + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run + // +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + run(options); +#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/81_blackwell_gemm_blockwise/CMakeLists.txt b/examples/81_blackwell_gemm_blockwise/CMakeLists.txt new file mode 100644 index 000000000..a4dc34d09 --- /dev/null +++ b/examples/81_blackwell_gemm_blockwise/CMakeLists.txt @@ -0,0 +1,57 @@ + +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +if (CUTLASS_NVCC_ARCHS MATCHES 100a) + +set(TEST_RANDOM --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes + +set(TEST_SMALL --m=256 --n=128 --k=128 --iterations=0) # Small problem sizes + +cutlass_example_add_executable( + 81_blackwell_gemm_blockwise + 81_blackwell_gemm_blockwise.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_EPILOGUE + TEST_SMALL +) + +cutlass_example_add_executable( + 81_blackwell_gemm_groupwise + 81_blackwell_gemm_groupwise.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_EPILOGUE + TEST_SMALL +) + +endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 079adff4c..a1a5c00ae 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -146,6 +146,7 @@ foreach(EXAMPLE 64_ada_fp8_gemm_grouped 65_distributed_gemm 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling + 68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling 69_hopper_mixed_dtype_grouped_gemm 70_blackwell_gemm 71_blackwell_gemm_with_collective_builder @@ -156,6 +157,7 @@ foreach(EXAMPLE 76_blackwell_conv 77_blackwell_fmha 78_blackwell_emulated_bf16x9_gemm + 81_blackwell_gemm_blockwise ) add_subdirectory(${EXAMPLE}) diff --git a/include/cutlass/detail/sm100_blockwise_scale_layout.hpp b/include/cutlass/detail/sm100_blockwise_scale_layout.hpp new file mode 100644 index 000000000..8f75bd256 --- /dev/null +++ b/include/cutlass/detail/sm100_blockwise_scale_layout.hpp @@ -0,0 +1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +/*! \file + \brief Block Wise Scale configs specific for SM100 Blockwise/Groupwise MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace cutlass::detail{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct Sm100BlockwiseScaleConfig { + + using ShapeSFA = Shape, int32_t>, Shape, int32_t>, int32_t>; + using ShapeSFB = Shape, int32_t>, Shape, int32_t>, int32_t>; + + using StrideSFA = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using StrideSFB = conditional_t,Stride<_0,int32_t>, int32_t>, + Stride,Stride<_0,_1>, int32_t>>; + + using LayoutSFA = Layout; + using LayoutSFB = Layout; + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFA() { + return LayoutSFA{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + smem_atom_layoutSFA(CtaShape_MNK cta_shape_mnk) { + static_assert(cute::is_static_v, "Expect static CTA shape"); + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K] = cta_shape_mnk; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeM)>{})); + } + else { + return make_stride(make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K] = cta_shape_mnk; + return make_layout( + make_shape(make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeM)>{}), + make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeK)>{})), + strides + ); + } + + + CUTE_HOST_DEVICE + static constexpr auto + deduce_layoutSFB() { + return LayoutSFB{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + smem_atom_layoutSFB(CtaShape_MNK cta_shape_mnk) { + static_assert(cute::is_static_v, "Expect static CTA shape"); + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeN)>{})); + } + else { + return make_stride(make_stride(_0{}, Int(CtaShape_MNK{}), SFVecSizeK)>{}), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K] = cta_shape_mnk; + return make_layout( + make_shape(make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeN)>{}), + make_shape(Int{}, Int(CtaShape_MNK{}), SFVecSizeK)>{})), + strides + ); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFA. 
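+  // Illustrative usage (a sketch only; the 1x128x128 granularities and the problem
+  // extents below are assumptions, not anything mandated by this header):
+  //
+  //   using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128>;
+  //   auto layout_SFA = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, L));
+  //   auto layout_SFB = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, L));
+  //
+  // The returned layout has shape ((SFVecSizeM, ceil_div(M, SFVecSizeM)),
+  // (SFVecSizeK, ceil_div(K, SFVecSizeK)), L); the vector modes carry stride 0, so
+  // every element inside one scale granule resolves to the same scale factor.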
+ template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFA(ProblemShape problem_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + if constexpr (majorSFA == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(M, SFVecSizeM))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto mk_layout = make_layout( + make_shape(make_shape(Int{}, cute::ceil_div(M, SFVecSizeM)), + make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), + strides + ); + + return make_layout(append(shape(mk_layout), L), append(stride(mk_layout), size(filter_zeros(mk_layout)))); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFB(ProblemShape problem_shape) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + auto strides = [&]() CUTLASS_LAMBDA_FUNC_INLINE { + auto [M, N, K, L] = problem_shape_MNKL; + + if constexpr (majorSFB == UMMA::Major::MN) { + return make_stride(make_stride(_0{}, _1{}), make_stride(_0{}, cute::ceil_div(N, SFVecSizeN))); + } + else { + return make_stride(make_stride(_0{}, cute::ceil_div(K, SFVecSizeK)), make_stride(_0{}, _1{})); + } + }(); + + auto [M, N, K, L] = problem_shape_MNKL; + auto nk_layout = make_layout( + make_shape(make_shape(Int{}, cute::ceil_div(N, SFVecSizeN)), + make_shape(Int{}, cute::ceil_div(K, SFVecSizeK))), + strides + ); + + return make_layout(append(shape(nk_layout), L), append(stride(nk_layout), size(filter_zeros(nk_layout)))); + } + +}; + +template +constexpr auto sm100_trivial_blockwise_scale_config(MmaTileShape_MNK) { + return Sm100BlockwiseScaleConfig(MmaTileShape_MNK{}), size<1>(MmaTileShape_MNK{}), size<2>(MmaTileShape_MNK{})>{}; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl new file mode 100644 index 000000000..c46687d16 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl @@ -0,0 +1,304 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/gemm/collective/builders/sm100_pipeline_carveout.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template< + int CapacityBytes, + class ElementA, + class ElementB, + class ElementScalar, + class ScaleShapeMNK, + class TileShapeMNK, + class MainloopPipelineStorage, + class TransformLoadPipelineStorage, + class TransformPipelineStorage, + int stages +> +constexpr int +sm100_compute_stage_count_or_override_blockwise(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template< + int CapacityBytes, + class ElementA, + class ElementB, + class ElementScalar, + class ScaleShapeMNK, + class TileShapeMNK, + class MainloopPipelineStorage, + class TransformLoadPipelineStorage, + class TransformPipelineStorage, + int stages +> +constexpr int +sm100_compute_stage_count_or_override_blockwise(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template< + int CapacityBytes, + class ElementA, + class ElementB, + class ElementScalar, + class ScaleShapeMNK, + class TileShapeMNK, + class MainloopPipelineStorage, + class TransformLoadPipelineStorage, + class TransformPipelineStorage, + int carveout_bytes> +constexpr int +sm100_compute_stage_count_or_override_blockwise(StageCountAutoCarveout stage_count) { + // For F8/F6/F4 sub-bytes, ElementA/B will be passed in as uint8_t + // For Planar Complex, ElementA/B will be passed in as cutlass::complex + // Each stage include (CollectiveMma::SharedStorage) + // 1. smem for A and smem for B (CollectiveMma::SharedStorage::TensorStorage) + // 2. 
one of each of the pipelines + constexpr auto pipeline_bytes = sizeof(MainloopPipelineStorage) + + sizeof(TransformLoadPipelineStorage) + sizeof(TransformPipelineStorage); + + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr auto scale_bits = cute::sizeof_bits_v; + + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(scale_bits * size<0>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) + + cutlass::bits_to_bytes(scale_bits * size<1>(ScaleShapeMNK{}) * size<2>(ScaleShapeMNK{})) + + static_cast(pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class GmemLayoutATagPair, + int AlignmentA, + class ElementB, + class GmemLayoutBTagPair, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm100, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATagPair, + AlignmentA, + ElementB, + GmemLayoutBTagPair, + AlignmentB, + ElementAccumulator, + TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + ClusterShape_MNK, // Static cluster shape or dynamic (int, int, _1) + StageCountType, + KernelScheduleType, + cute::enable_if_t< + not cute::is_tuple_v && not cute::is_tuple_v && + not cute::is_complex_v && not cute::is_complex_v && + cute::is_tuple_v && cute::is_tuple_v && + // Dense Gemm + cute::is_base_of_v && + // Alignment check + detail::sm1xx_gemm_is_aligned()>> +{ + static_assert(cute::is_static_v, "TileShape has to be static"); + static_assert(detail::check_input_datatypes(), "Incorrect input types"); + + using GmemLayoutATag = cute::remove_cvref_t(GmemLayoutATagPair{}))>; + using GmemLayoutSFATag = cute::remove_cvref_t(GmemLayoutATagPair{}))>; + using GmemLayoutBTag = cute::remove_cvref_t(GmemLayoutBTagPair{}))>; + using GmemLayoutSFBTag = cute::remove_cvref_t(GmemLayoutBTagPair{}))>; + + static_assert(cute::depth(GmemLayoutSFATag{}) == 2 and cute::depth(GmemLayoutSFBTag{}) == 2, + "Expect SFA and SFB layout to be depth of two with shape ((SFVecMN, restMN),(SFVecK, restK), L)"); + static_assert(size<1,0>(GmemLayoutSFATag{}) == size<1, 0>(GmemLayoutSFBTag{}), + "SFA and SFB must have equivalent SF vector sizes along K"); + + static constexpr cute::UMMA::Major UmmaMajorA = cutlass::gemm::collective::detail::tag_to_umma_major_A(); + static constexpr cute::UMMA::Major UmmaMajorB = cutlass::gemm::collective::detail::tag_to_umma_major_B(); + + // Data type used by MMA instruction + using ElementAMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + using ElementBMma = decltype(cutlass::gemm::collective::detail::sm100_kernel_input_element_to_mma_input_element()); + + static constexpr bool is_2sm = cute::is_base_of_v || + (not cute::is_base_of_v && + not cute::is_base_of_v && + cute::is_static_v && + cute::get<0>(ClusterShape_MNK{}) % 2 == 0 ); + + static_assert(detail::sm100_gemm_check_for_f8f6f4_mix8bit_requirement(), + "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); + using TiledMma = decltype(detail::sm100_make_trivial_tiled_mma< + ElementAMma, ElementBMma, ElementAccumulator, + 
decltype(cute::product_each(TileShape_MNK{})), ClusterShape_MNK, + UmmaMajorA, UmmaMajorB, KernelScheduleType>()); + + using ElementAMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementAMma>; + using ElementBMma_SmemAllocType = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using AtomThrID = typename TiledMma::AtomThrID; + + using AtomThrShapeMNK = cute::Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + using CtaTileShape_MNK = decltype(cute::shape_div(TileShape_MNK{}, AtomThrShapeMNK{})); + + // ((MMA_TILE_M,MMA_TILE_K), MMA_M, MMA_K) + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(cute::size<0>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + // ((MMA_TILE_N,MMA_TILE_K), MMA_N, MMA_K) + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(cute::size<1>(TileShape_MNK{}), + cute::size<2>(TileShape_MNK{})))); + + using BlockTileA_M = decltype(cute::size<0,0>(MmaShapeA_MK{}) * cute::size<1>(MmaShapeA_MK{})); + using BlockTileA_K = decltype(cute::size<0,1>(MmaShapeA_MK{}) * cute::size<2>(MmaShapeA_MK{})); + using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); + using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); + + static_assert(BlockTileA_K{} == BlockTileB_K{}, "Block tile Ks should be equal"); + + using SmemShape_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(TileShape_MNK{}), shape_div(shape<1>(TileShape_MNK{}), size<1>(TileShape_MNK{}) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(TileShape_MNK{})); + + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_A( + ClusterShape_MNK{}, AtomThrID{})); + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm100_cluster_shape_to_tma_atom_B( + ClusterShape_MNK{}, AtomThrID{})); + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorA, ElementAMma_SmemAllocType, SmemShape_M, SmemShape_K>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + UmmaMajorB, ElementBMma_SmemAllocType, SmemShape_N, SmemShape_K>()); + static constexpr uint32_t TotalTmemRows = 128; + static constexpr uint32_t Sm100TmemCapacityColumns = 512; + static constexpr uint32_t TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns; + static constexpr uint32_t AccumulatorPipelineStageCount = (is_2sm || (!is_2sm && size(shape<0,0>(MmaShapeA_MK{}) > 64))) ? + TotalTmem / (cute::size<0>(CtaTileShape_MNK{}) * cute::size<1>(CtaTileShape_MNK{})) + : (Sm100TmemCapacityColumns / cute::size<1>(CtaTileShape_MNK{})) * 2; // 1SM MMA_M = 64 case + static_assert(AccumulatorPipelineStageCount > 0, "Accumulator pipeline stage count must be positive. This error probably means that TileShape_MNK and/or TiledMma::ThrLayoutVMNK are wrong."); + + // Calculate scheduler pipeline stages. Having one more stage than the accumulator allows more latency hiding. + using StrideA = cutlass::gemm::TagToStrideA_t; + using InternalStrideA = cute::remove_pointer_t; + // Grouped GEMM (where Stride type is Stride*) does not use CLC based scheduler. + // SchedulerPipelineStageCount could be set to zero for Grouped GEMM, but we shouldn't define CLC Pipeline's barrier arrays of size zero. 
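+  // Worked example (illustrative numbers only, not a constraint of this builder): for a
+  // 1SM MMA with a 128x128 CTA tile, the accumulator count above evaluates to
+  // TotalTmem / (CTA_M * CTA_N) = (128 * 512) / (128 * 128) = 4 accumulator stages,
+  // so a non-grouped GEMM (plain stride, not pointer-to-stride) gets 4 + 1 = 5
+  // scheduler stages from the expression below, while grouped GEMM falls back to 1.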
+ static constexpr uint32_t SchedulerPipelineStageCount = cute::is_same_v ? (AccumulatorPipelineStageCount + 1) : 1; + + static constexpr uint32_t KernelSmemCarveout = detail::Sm100DenseGemmTmaUmmaCarveout< + ClusterShape_MNK, + AccumulatorPipelineStageCount, + SchedulerPipelineStageCount, + detail::CLCResponseSize, + false + >::KernelSmemCarveout; + // Reduce SMEM capacity available for buffers considering barrier allocations. + static constexpr int Sm100ReducedSmemCapacityBytes = cutlass::gemm::collective::detail::sm100_smem_capacity_bytes - KernelSmemCarveout; + + using SmemTileShape = cute::Shape; + using MainloopPipelineStorage = typename cutlass::PipelineTmaUmmaAsync<1>::SharedStorage; + using TransformLoadPipelineStorage = typename cutlass::PipelineAsync<1>::SharedStorage; + using TransformPipelineStorage = typename cutlass::PipelineUmmaAsync<1>::SharedStorage; + + static constexpr int ScaleGranularityM = size<0,0>(GmemLayoutSFATag{}); + static constexpr int ScaleGranularityN = size<0,0>(GmemLayoutSFBTag{}); + static constexpr int ScaleGranularityK = size<1,0>(GmemLayoutSFBTag{}); + + static_assert(size<0>(CtaTileShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<1>(CtaTileShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<2>(CtaTileShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape"); + + using BlockTileScale_M = Int(TileShape_MNK{}) / ScaleGranularityM>; + using BlockTileScale_N = Int(TileShape_MNK{}) / ScaleGranularityN>; + using BlockTileScale_K = Int(TileShape_MNK{}) / ScaleGranularityK>; + + using ScaleTileShape = cute::Shape; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::sm100_compute_stage_count_or_override_blockwise< + Sm100ReducedSmemCapacityBytes, ElementAMma_SmemAllocType, ElementBMma_SmemAllocType, + ElementAccumulator, ScaleTileShape, SmemTileShape, MainloopPipelineStorage, + TransformLoadPipelineStorage, TransformPipelineStorage>(StageCountType{}); + static_assert(PipelineStages > 0, "Smem usage is too high. 
Can't create any SMEM buffers for A, B, and scales."); + + using DispatchPolicy = cutlass::gemm::MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling< + PipelineStages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape_MNK>; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cute::tuple, cutlass::gemm::TagToStrideA_t>, + ElementB, + cute::tuple, cutlass::gemm::TagToStrideB_t>, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; +}; + + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 504abcc66..01c78f5ae 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -1046,8 +1046,7 @@ template < class TileShape_MNK, class ClusterShape_MNK, class StageCountType, - int ScaleGranularityM_, - int ScaleGranularityN_ + class KernelScheduleType > struct CollectiveBuilder< arch::Sm90, @@ -1062,11 +1061,16 @@ struct CollectiveBuilder< TileShape_MNK, ClusterShape_MNK, StageCountType, - KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum, + KernelScheduleType, cute::enable_if_t< - not detail::is_use_rmem_A()> + cute::is_same_v and + not detail::is_use_rmem_A() + > > { - using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; + + static constexpr auto ScaleGranularityM_ = KernelScheduleType::ScaleGranularityM; + static constexpr auto ScaleGranularityN_ = KernelScheduleType::ScaleGranularityN; + static constexpr auto ScalePromotionInterval_ = KernelScheduleType::ScalePromotionInterval; static_assert(is_static::value); static_assert(is_static::value); @@ -1076,12 +1080,12 @@ struct CollectiveBuilder< static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); - static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v); + static constexpr bool IsArrayOfPointersGemm = ( + cute::is_base_of_v || + cute::is_base_of_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert((!IsFP8Input || !IsArrayOfPointersGemm), - "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now."); + static_assert(IsFP8Input, "Warp Specialized gemm with FP8 BlockScaled Accumulator is only compatible with FP8 Blocked Scaled version right now."); // For fp32 types, map to tf32 MMA value type using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; @@ -1091,10 +1095,9 @@ struct CollectiveBuilder< static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - static constexpr bool IsCooperative = cute::is_any_of_v>; + static constexpr bool IsCooperative = cute::is_base_of_v || + cute::is_base_of_v; + using AtomLayoutMNK = cute::conditional_t>, Layout>>; @@ -1121,7 +1124,9 @@ struct CollectiveBuilder< static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + using DispatchPolicy = cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8>; using 
SmemCopyAtomA = void; using SmemCopyAtomB = void; diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index 9623900b2..dd139c281 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -43,6 +43,7 @@ #include "cutlass/gemm/collective/builders/sm100_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_9xBF16_umma_builder.inl" #include "cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl" +#include "cutlass/gemm/collective/builders/sm100_blockwise_umma_builder.inl" #endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index e33c06a77..792bde608 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -50,6 +50,7 @@ #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" +#include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp" #if !defined(__CUDACC_RTC__) #include "cutlass/gemm/collective/sm100_mma_warpspecialized.hpp" @@ -59,5 +60,6 @@ #include "cutlass/gemm/collective/sm100_blockscaled_mma_warpspecialized.hpp" #include "cutlass/gemm/collective/sm100_blockscaled_mma_array_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp" #endif // !defined(__CUDACC_RTC__) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/fp8_accumulation.hpp b/include/cutlass/gemm/collective/fp8_accumulation.hpp index 9dff91a5d..9597fbe74 100644 --- a/include/cutlass/gemm/collective/fp8_accumulation.hpp +++ b/include/cutlass/gemm/collective/fp8_accumulation.hpp @@ -223,6 +223,30 @@ public: mma_count_ = 0; } } + + /// scale (multiply_add) the results from the MMA accumulators to main accumulator without checking the counter. + CUTLASS_DEVICE + void scale(ElementAccumulator const &scale) { + scale_core(scale); + } + + template < + class EngineScale, + class LayoutScale> + CUTLASS_DEVICE + void scale(const cute::Tensor &scale) { + scale_core(scale); + } + + template < + class EngineScaleA, + class LayoutScaleA, + class EngineScaleB, + class LayoutScaleB> + CUTLASS_DEVICE + void scale(const cute::Tensor &scaleA, const cute::Tensor &scaleB) { + scale_core(scaleA, scaleB); + } /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed. CUTLASS_DEVICE diff --git a/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp new file mode 100644 index 000000000..8fc9331cc --- /dev/null +++ b/include/cutlass/gemm/collective/sm100_mma_warpspecialized_blockwise_scaling.hpp @@ -0,0 +1,1156 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/cluster.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/trace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/detail/sm100_tmem_helper.hpp" +#include "cutlass/detail/sm100_blockwise_scale_layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +// Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one +template < + int Stages, + int SchedulerPipelineStageCount, + int AccumulatorPipelineStageCount, + class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) + class ElementA_, + class StridePairA_, + class ElementB_, + class StridePairB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>, + TileShape_, + ElementA_, + StridePairA_, + ElementB_, + StridePairB_, + 
TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using TiledMma = TiledMma_; + using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; + + using DispatchPolicy = MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling< + Stages, + SchedulerPipelineStageCount, + AccumulatorPipelineStageCount, + ClusterShape>; + using TileShape = TileShape_; + + using ElementA = ElementA_; + using ElementAMma = typename TiledMma::ValTypeA; + using StrideA = cute::remove_cvref_t(StridePairA_{}))>; + using LayoutSFA = cute::remove_cvref_t(StridePairA_{}))>; + using ElementSFA = typename TiledMma::ValTypeC; + using ElementB = ElementB_; + using ElementBMma = typename TiledMma::ValTypeB; + using StrideB = cute::remove_cvref_t(StridePairB_{}))>; + using LayoutSFB = cute::remove_cvref_t(StridePairB_{}))>; + using ElementSFB = typename TiledMma::ValTypeC; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + + static constexpr int ScaleGranularityM = size<0,0>(LayoutSFA{}); + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static_assert(size<0>(TileShape{}) % ScaleGranularityM == 0 and ScaleGranularityM <= size<0>(TileShape{}), "Scale Granularity M must divide Tile Shape"); + + static constexpr int ScaleGranularityN = size<0,0>(LayoutSFB{}); + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + static_assert(size<1>(TileShape{}) % ScaleGranularityN == 0 and ScaleGranularityN <= size<1>(TileShape{}), "Scale Granularity N must divide Tile Shape"); + + static_assert(size<1, 0>(LayoutSFA{}) == size<1, 0>(LayoutSFB{}), "Vector size K must be equal for SFA and SFB"); + + static constexpr int ScaleGranularityK = size<1, 0>(LayoutSFA{}); + static constexpr int ScaleKsPerTile = size<2>(TileShape{}) / ScaleGranularityK; + static_assert(size<2>(TileShape{}) % ScaleGranularityK == 0 and ScaleGranularityK <= size<2>(TileShape{}), "Scale Granularity K must divide Tile Shape"); + static_assert(ScaleGranularityK % size<2>(typename TiledMma::AtomShape_MNK{}) == 0, "Scale Granularity K must be divisible by MMA_K"); + + static constexpr int K_BLOCK_MMAS_PER_SCALE_K = ScaleGranularityK / size<2>(typename TiledMma::AtomShape_MNK{}); + + static constexpr int TILE_M = size<0>(TileShape{}); + static constexpr int TILE_N = size<1>(TileShape{}); + + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig(LayoutSFA{}.stride()) == 1 ? UMMA::Major::MN : UMMA::Major::K, + size<0,1>(LayoutSFB{}.stride()) == 1 ? 
UMMA::Major::MN : UMMA::Major::K>; + + CUTE_STATIC_ASSERT_V(evenly_divides(TileShape{}, tile_shape(TiledMma{})), + "Static cluster shape used: TileShape should be evenly divided by TiledMma"); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + + static_assert(size<0>(AtomThrShapeMNK{}) == 1, "2SM MMA is not yet supported"); + + static_assert(size<0>(CtaShape_MNK{}) >= ScaleGranularityM, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<1>(CtaShape_MNK{}) >= ScaleGranularityN, "Scale Granularity must be smaller than or equal to the tile shape"); + static_assert(size<2>(CtaShape_MNK{}) >= ScaleGranularityK, "Scale Granularity must be smaller than or equal to the tile shape"); + + using SmemLayoutAtomSFA = decltype(ScaleConfig::smem_atom_layoutSFA(CtaShape_MNK{})); + using SmemLayoutAtomSFB = decltype(ScaleConfig::smem_atom_layoutSFB(CtaShape_MNK{})); + + // Define A and B block shapes for reduced size TMA_LOADs + using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); + using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + DispatchPolicy::Stages, + ClusterShape, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using Load2TransformPipeline = cutlass::PipelineAsync; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using Mma2TransformPipeline = cutlass::PipelineUmmaAsync< + AccumulatorPipelineStageCount, + AtomThrShapeMNK>; + using Mma2TransformPipelineState = typename Mma2TransformPipeline::PipelineState; + + // Two arrivals per CTA (1 arrival and 1 arrival through cp.async.mbarrier) + static constexpr int NumLoad2TransformProducerThreadEvents = 2; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtomA must be rank 2 (M,K)"); + static_assert(((size<0,0>(MmaShapeA_MK{}) * size<1>(MmaShapeA_MK{})) % size<0>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeA_MK{}) * size<2>(MmaShapeA_MK{})) % size<1>(SmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtomB must be rank 2 (N,K)"); + static_assert(((size<0,0>(MmaShapeB_NK{}) * size<1>(MmaShapeB_NK{})) % size<0>(SmemLayoutAtomB{})) == 0, + 
"SmemLayoutAtom must evenly divide tile shape."); + static_assert(((size<0,1>(MmaShapeB_NK{}) * size<2>(MmaShapeB_NK{})) % size<1>(SmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert(cute::is_void_v, + "SM100 UMMA cannot have a non-void copy atom for smem sourced instructions."); + + // Tile along K mode first before tiling over MN. PIPE mode last as usual. + // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // (MMA_TILE_M,MMA_TILE_K),MMA_M,MMA_K,PIPE) + using SmemLayoutA = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomA{}, + append(MmaShapeA_MK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + // (MMA_TILE_N,MMA_TILE_K),MMA_N,MMA_K,PIPE) + using SmemLayoutB = decltype(UMMA::tile_to_mma_shape( + SmemLayoutAtomB{}, + append(MmaShapeB_NK{}, Int{}), + cute::conditional_t(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + static_assert( + (size(AtomThrShapeMNK{}) == 1 && + (cute::is_same_v || cute::is_same_v)) || + (size(AtomThrShapeMNK{}) == 2 && + (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopy - invalid TMA copy atom specified."); + + using TmaInternalElementA = cute::conditional_t, cutlass::tfloat32_t, ElementAMma>; + using TmaInternalElementB = cute::conditional_t, cutlass::tfloat32_t, ElementBMma>; + + using SmemAllocTypeA = cute::conditional_t < 8, uint8_t, ElementAMma>; + using SmemAllocTypeB = cute::conditional_t < 8, uint8_t, ElementBMma>; + + using BitTypeElementA = cute::uint_bit_t>; + using BitTypeElementB = cute::uint_bit_t>; + + using ArrayElementA = cute::conditional_t; + using ArrayElementB = cute::conditional_t; + + using RuntimeDataTypeA = cute::conditional_t; + using RuntimeDataTypeB = cute::conditional_t; + + using SmemLayoutScaleA = decltype(make_layout( + append(shape(SmemLayoutAtomSFA{}), Int{}), + append(stride(SmemLayoutAtomSFA{}), size(filter_zeros(SmemLayoutAtomSFA{}))) + )); + using SmemLayoutScaleB = decltype(make_layout( + append(shape(SmemLayoutAtomSFB{}), Int{}), + append(stride(SmemLayoutAtomSFB{}), size(filter_zeros(SmemLayoutAtomSFB{}))) + )); + + // Scaling gmem-to-smem copy atom + static constexpr int LeadingScalesPerTileSFA = size<0,1>(LayoutSFA{}.stride()) == 1 ? ScaleMsPerTile : ScaleKsPerTile; + using ScaleCopyTypeA = cute::uint_byte_t(sizeof(ElementAccumulator)) * LeadingScalesPerTileSFA, 16)>; + using SmemScalingCopyAtomA = Copy_Atom, ElementAccumulator>; + + static constexpr int LeadingScalesPerTileSFB = size<0,1>(LayoutSFB{}.stride()) == 1 ? 
ScaleNsPerTile : ScaleKsPerTile; + using ScaleCopyTypeB = cute::uint_byte_t(sizeof(ElementAccumulator)) * LeadingScalesPerTileSFB, 16)>; + using SmemScalingCopyAtomB = Copy_Atom, ElementAccumulator>; + + using TiledCopyScaleA = decltype(make_tiled_copy(SmemScalingCopyAtomA{}, Layout>{}, Layout>>{})); + using TiledCopyScaleB = decltype(make_tiled_copy(SmemScalingCopyAtomB{}, Layout>{}, Layout>>{})); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine> smem_scale_A; + cute::ArrayEngine> smem_scale_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + + using Load2TransformPipelineStorage = typename Load2TransformPipeline::SharedStorage; + Load2TransformPipelineStorage transform2load_pipeline; + + using Mma2TransformPipelineStorage = typename Mma2TransformPipeline::SharedStorage; + Mma2TransformPipelineStorage mma2transform_pipeline; + }; + + // Expose shared storage for tensors/pipelines separately to allow kernel layer to reorder them. + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + using Mma2TransformPipelineStorage = typename SharedStorage::Mma2TransformPipelineStorage; + using Load2TransformPipelineStorage = typename SharedStorage::Load2TransformPipelineStorage; + + // Only one thread issues the TMA and updates the barriers in a 2SM MMA, adjust bytes accordingly + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutA{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(SmemLayoutB{})) * cute::sizeof_bits_v); + + template + struct TmemStorage { + AccTensor accumulators; + }; + + template< + class KTileCount, + class GTensorPartitionedA, class GTensorPartitionedB, + class STensorA, class STensorB, + class GTensorPartitionedScaleA, class GTensorPartitionedScaleB, + class IdentTensorPartitionedScaleA, class IdentTensorPartitionedScaleB, + class STensorScaleA, class STensorScaleB + > + struct LoadParams { + // for scheduler + KTileCount k_tiles; + // for input tensor values + GTensorPartitionedA tAgA_mkl; + GTensorPartitionedB tBgB_nkl; + STensorA tAsA; + STensorB tBsB; + + GTensorPartitionedScaleA tSFAgSFA_mkl; + GTensorPartitionedScaleB tSFBgSFB_nkl; + IdentTensorPartitionedScaleA tSFAIdentSFA_mkl; + IdentTensorPartitionedScaleB tSFBIdentSFB_nkl; + STensorScaleA tSFAsSFA; + STensorScaleB tSFBsSFB; + + // the TMA multicast masks + uint16_t mcast_mask_a; + uint16_t mcast_mask_b; + + LayoutSFA layout_SFA; + LayoutSFB layout_SFB; + + CUTLASS_DEVICE + LoadParams ( + KTileCount k_tiles_, + GTensorPartitionedA tAgA_mkl_, GTensorPartitionedB tBgB_nkl_, + STensorA tAsA_, STensorB tBsB_, + GTensorPartitionedScaleA tSFAgSFA_mkl_, GTensorPartitionedScaleB tSFBgSFB_nkl_, + IdentTensorPartitionedScaleA tSFAIdentSFA_mkl_, IdentTensorPartitionedScaleB tSFBIdentSFB_nkl_, + STensorScaleA tSFAsSFA_, STensorScaleB tSFBsSFB_, + uint16_t mcast_mask_a_, uint16_t mcast_mask_b_, + LayoutSFA layout_SFA_, LayoutSFB layout_SFB_) + : k_tiles(k_tiles_) + , tAgA_mkl(tAgA_mkl_), tBgB_nkl(tBgB_nkl_) + , tAsA(tAsA_), tBsB(tBsB_) + , tSFAgSFA_mkl(tSFAgSFA_mkl_), tSFBgSFB_nkl(tSFBgSFB_nkl_) + , tSFAIdentSFA_mkl(tSFAIdentSFA_mkl_), tSFBIdentSFB_nkl(tSFBIdentSFB_nkl_) + , tSFAsSFA(tSFAsSFA_), tSFBsSFB(tSFBsSFB_) + , mcast_mask_a(mcast_mask_a_), 
mcast_mask_b(mcast_mask_b_) + , layout_SFA(layout_SFA_), layout_SFB(layout_SFB_) {} + }; + + template + struct MmaParams { + TiledMma tiled_mma; + FragmentA tCrA; + FragmentB tCrB; + + CUTLASS_DEVICE + MmaParams ( + TiledMma tiled_mma_, + FragmentA tCrA_, FragmentB tCrB_) + : tiled_mma(tiled_mma_) + , tCrA(tCrA_), tCrB(tCrB_) {} + }; + + template< + class STensorScaleA, class STensorScaleB + > + struct TransformParams { + // for scheduler + + STensorScaleA sSFA; + STensorScaleB sSFB; + + CUTLASS_DEVICE + TransformParams ( + STensorScaleA sSFA_, STensorScaleB sSFB_) + : sSFA(sSFA_), sSFB(sSFB_) {} + }; + + + // Host side kernel arguments + struct Arguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementB const* ptr_B{nullptr}; + StrideB dB{}; + ElementAccumulator const* ptr_scale_A{nullptr}; + LayoutSFA layout_SFA{}; + ElementAccumulator const* ptr_scale_B{nullptr}; + LayoutSFB layout_SFB{}; + RuntimeDataTypeA runtime_data_type_a{}; + RuntimeDataTypeB runtime_data_type_b{}; + }; + + // Device side kernel params + struct Params { + using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(conditional_return(make_shape(uint32_t(0), uint32_t(0), Int<1>{}), ClusterShape{})), + make_tile(typename TiledMma::AtomThrID{}))); + + using TMA_A = decltype(make_tma_atom_A_sm100( + GmemTiledCopyA{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + using TMA_B = decltype(make_tma_atom_B_sm100( + GmemTiledCopyB{}, + make_tensor(recast_ptr(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + ClusterLayout_VMNK{}) + ); + + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_A tma_load_a_fallback; + TMA_B tma_load_b_fallback; + dim3 cluster_shape_fallback; + RuntimeDataTypeA runtime_data_type_a; + RuntimeDataTypeB runtime_data_type_b; + + ElementAccumulator const* ptr_scale_A; + LayoutSFA layout_SFA; + ElementAccumulator const* ptr_scale_B; + LayoutSFB layout_SFB; + }; + + CUTLASS_DEVICE + CollectiveMma(Params const& params, ClusterShape cluster_shape, uint32_t block_rank_in_cluster) + : cluster_shape_(cluster_shape) + , block_rank_in_cluster_(block_rank_in_cluster) + , runtime_data_type_a_(params.runtime_data_type_a) + , runtime_data_type_b_(params.runtime_data_type_b) { + if constexpr (IsDynamicCluster) { + const bool is_fallback_cluster = (cute::size<0>(cluster_shape_) == params.cluster_shape_fallback.x && + cute::size<1>(cluster_shape_) == params.cluster_shape_fallback.y); + observed_tma_load_a_ = is_fallback_cluster ? ¶ms.tma_load_a_fallback : ¶ms.tma_load_a; + observed_tma_load_b_ = is_fallback_cluster ? 
¶ms.tma_load_b_fallback : ¶ms.tma_load_b; + } + else { + observed_tma_load_a_ = ¶ms.tma_load_a; + observed_tma_load_b_ = ¶ms.tma_load_b; + } + } + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace, + cutlass::KernelHardwareInfo const& hw_info = cutlass::KernelHardwareInfo{}) { + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = recast_ptr(args.ptr_A); + auto ptr_B = recast_ptr(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape); + + // Cluster layout for TMA construction + auto cluster_layout_vmnk = tiled_divide(make_layout(cluster_shape), make_tile(typename TiledMma::AtomThrID{})); + auto cluster_shape_fallback = cutlass::detail::select_cluster_shape(ClusterShape{}, hw_info.cluster_shape_fallback); + auto cluster_layout_vmnk_fallback = tiled_divide(make_layout(cluster_shape_fallback), make_tile(typename TiledMma::AtomThrID{})); + typename Params::TMA_A tma_load_a = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_B tma_load_b = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk); + + typename Params::TMA_A tma_load_a_fallback = make_tma_atom_A_sm100( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + typename Params::TMA_B tma_load_b_fallback = make_tma_atom_B_sm100( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_,cute::Int<0>{}), + TileShape{}, + TiledMma{}, + cluster_layout_vmnk_fallback); + + return { + tma_load_a, + tma_load_b, + tma_load_a_fallback, + tma_load_b_fallback, + hw_info.cluster_shape_fallback, + args.runtime_data_type_a, + args.runtime_data_type_b, + args.ptr_scale_A, + args.layout_SFA, + args.ptr_scale_B, + args.layout_SFB + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + static constexpr bool IsF8F6F4 = detail::is_sm100_mma_f8f6f4(); + constexpr int tma_alignment_bits_A = cutlass::detail::get_input_alignment_bits(); + constexpr int tma_alignment_bits_B = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits_A / cute::sizeof_bits::value; + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits_B / cute::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool implementable_sf = cutlass::detail::check_alignment(args.layout_SFA); + implementable_sf = implementable_sf && 
cutlass::detail::check_alignment(args.layout_SFB); + + if (!implementable_sf) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for Scale Factors.\n"); + } + + return implementable && implementable_sf; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE void + prefetch_tma_descriptors() { + cute::prefetch_tma_descriptor(observed_tma_load_a_->get_tma_descriptor()); + cute::prefetch_tma_descriptor(observed_tma_load_b_->get_tma_descriptor()); + } + + /// Construct A Single Stage's Accumulator Shape + CUTLASS_DEVICE static + auto + partition_accumulator_shape() { + auto acc_shape = partition_shape_C(TiledMma{}, take<0,2>(TileShape{})); // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N) + + return acc_shape; + } + + template + CUTLASS_DEVICE static + auto + slice_accumulator(TmemStorage tmem_storage, int stage) { + return cute::make_tuple(tmem_storage.accumulators(_,_,_,stage)); + } + + template + CUTLASS_DEVICE static + auto + init_tmem_tensors(EpilogueTile epi_tile) { + TiledMma tiled_mma; + auto acc_shape = partition_accumulator_shape(); + // ((MMA_TILE_M,MMA_TILE_N),MMA_M,MMA_N,ACC_PIPE) where ACC_PIPE=2 so we can double buffer our accumulators for mainloop and epilogue. + Tensor accumulators = cutlass::detail::make_sm100_accumulator( + tiled_mma, acc_shape, EpilogueTile{}); + TmemStorage tmem_storage; + tmem_storage.accumulators = accumulators; + return tmem_storage; + } + + template + CUTLASS_DEVICE static + void + set_tmem_offsets(TmemStorage& tmem_storage, uint32_t tmem_base_addr) { + tmem_storage.accumulators.data() = tmem_base_addr; + } + + /// Set up the data needed by this collective for load. + /// Return load params containing + /// gA_mkl - The tiled tma tensor for input A + /// gB_nkl - The tiled tma tensor for input B + /// tAsA - partitioned smem tensor for A + /// tBsB - partitioned smem tensor for B + /// tSFAgSFA_mkl - partitioned gmem tensor for SFA + /// tSFBgSFB_nkl - partitioned gmem tensor for SFB + /// tSFAIdentSFA_mkl - partitioned identity tensor for SFA in gmem + /// tSFBIdentSFB_nkl - partitioned identity tensor for SFB in gmem + /// tSFAsSFA - partitioned smem tensor for SFA + /// tSFBsSFB - partitioned smem tensor for SFB + /// mcast_mask_a - tma multicast mask for A + /// mcast_mask_b - tma multicast mask for B + /// layout_SFA - layout of SFA in gmem + /// layout_SFB - layout of SFB in gmem + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); + Tensor mB_nkl = observed_tma_load_b_->get_tma_tensor(make_shape(N,K,L)); + + // Tile the tensors and defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + // Partition for this CTA + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); + + Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) + Tensor tCgB_nkl = cta_mma.partition_B(gB_nkl); // (MMA, MMA_N, MMA_K, n, k, l) + + Tensor sA = 
make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (MMA,MMA_N,MMA_K,PIPE) + + // Define the CTA-in-cluster Layout and Coord + Layout cta_layout_mnk = make_layout(cluster_shape_); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster_); + + // Project the cta_layout for tma_a along the n-modes + auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); + + // Project the cta_layout for tma_b along the m-modes + auto [tBgB_nkl, tBsB] = tma_partition(*observed_tma_load_b_, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(sB), group_modes<0,3>(tCgB_nkl)); + + // TMA Multicast Masks + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + // Scales + + Tensor mSFA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), mainloop_params.layout_SFA); // (m,k,l) + Tensor mSFB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), mainloop_params.layout_SFB); // (n,k,l) + + Tensor SFA_mkl_ident = make_identity_tensor(shape(mainloop_params.layout_SFA)); + + Tensor SFB_nkl_ident = make_identity_tensor(shape(mainloop_params.layout_SFB)); + + // Tile the tensors and defer the slice + Tensor gSFA_mkl = local_tile(mSFA_mkl, CtaShape_MNK{}, + make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor gSFB_nkl = local_tile(mSFB_nkl, CtaShape_MNK{}, + make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + Tensor identSFA_mkl = local_tile(SFA_mkl_ident, CtaShape_MNK{}, + make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) + Tensor identSFB_nkl = local_tile(SFB_nkl_ident, CtaShape_MNK{}, + make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N, BLK_K, n, k, l) + + static_assert(rank(decltype(gSFA_mkl){}) == 5); + static_assert(rank(decltype(gSFB_nkl){}) == 5); + + // 1 thread copies entire set of scalar + TiledCopyScaleA scale_copy_a{}; + TiledCopyScaleB scale_copy_b{}; + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(_0{}); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(_0{}); + + Tensor sSFA = make_tensor(make_smem_ptr(shared_tensors.smem_scale_A.begin()), + SmemLayoutScaleA{}); // (CTA_M,CTA_K,P) + Tensor sSFB = make_tensor(make_smem_ptr(shared_tensors.smem_scale_B.begin()), + SmemLayoutScaleB{}); // (CTA_M,CTA_K,P) + + Tensor tSFAgSFA_mkl = thr_scale_copy_a.partition_S(gSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) + Tensor tSFAIdentSFA_mkl = thr_scale_copy_a.partition_S(identSFA_mkl); // (CPY, BLK_M, BLK_K, m, k, l) + + Tensor tSFAsSFA = thr_scale_copy_a.partition_D(sSFA); + + Tensor tSFBgSFB_nkl = thr_scale_copy_b.partition_S(gSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) + Tensor tSFBIdentSFB_nkl = thr_scale_copy_b.partition_S(identSFB_nkl); // (CPY, BLK_N, BLK_K, m, k, l) + Tensor tSFBsSFB = thr_scale_copy_b.partition_D(sSFB); + + static_assert(rank(decltype(tSFAgSFA_mkl){}) == 6); + static_assert(rank(decltype(tSFBgSFB_nkl){}) == 6); + + LoadParams load_params { + shape<3>(gA_mkl), // for scheduler + tAgA_mkl, tBgB_nkl, tAsA, tBsB, // for input tensor values + tSFAgSFA_mkl, tSFBgSFB_nkl, // for input scale tensor 
values + tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, // for predicating scale tensor copies + tSFAsSFA, tSFBsSFB, // for scale tensor values + mcast_mask_a, mcast_mask_b, // multicast masks + mainloop_params.layout_SFA, // for predicating scale tensor copies + mainloop_params.layout_SFB // for predicating scale tensor copies + }; + return load_params; + } + + /// Set up the data needed by this collective for mma compute. + template + CUTLASS_DEVICE auto + mma_init( + [[maybe_unused]] TmemStorage tmem_tensors, + TensorStorage& shared_tensors) const { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Allocate "fragments/descriptors" for A and B matrices + Tensor tCrA = TiledMma::make_fragment_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = TiledMma::make_fragment_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<3>(sB)); + + TiledMma tiled_mma; + + if constexpr (IsRuntimeDataType) { + // Update instruction descriptor according to runtime argument. + // Applying bitmask (0b111) to help compiler deduce that the conversion and assignment are safe. + tiled_mma.idesc_.a_format_ = uint8_t(runtime_data_type_a_) & 0b111; + tiled_mma.idesc_.b_format_ = uint8_t(runtime_data_type_b_) & 0b111; + } + MmaParams mma_params { + tiled_mma, + tCrA, tCrB + }; + return mma_params; + } + + /// Set up the data needed by this collective for transform. + template + CUTLASS_DEVICE auto + transform_init( + ProblemShape_MNKL const& problem_shape_MNKL, + TensorStorage& shared_tensors) const { + using X = Underscore; + + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + Tensor sSFA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.begin()), + SmemLayoutScaleA{}); // (ScaleMsPerTile,ScaleKsPerTile,P) + Tensor sSFB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.begin()), + SmemLayoutScaleB{}); // (ScaleNsPerTile,ScaleKsPerTile,P) + + + TransformParams transform_params { + sSFA, sSFB // for input tensor values + }; + return transform_params; + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class LoadParams, + class TileCoordMNKL, + class KTileIterator + > + CUTLASS_DEVICE auto + load( + MainloopPipeline mainloop_pipeline, + Load2TransformPipeline load2transform_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + Load2TransformPipelineState load2transform_pipe_producer_state, + LoadParams const& load_inputs, + TileCoordMNKL const& cta_coord_mnkl, + KTileIterator k_tile_iter, int k_tile_count) { + + auto [unused_k_tiles, + tAgA_mkl, tBgB_nkl, tAsA, tBsB, + tSFAgSFA_mkl, tSFBgSFB_nkl, + tSFAIdentSFA_mkl, tSFBIdentSFB_nkl, + tSFAsSFA, tSFBsSFB, + mcast_mask_a, mcast_mask_b, + layout_SFA, layout_SFB] = load_inputs; + + // slice out the work coord from partitioned tensors + Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); + Tensor tBgB = tBgB_nkl(_, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + TiledCopyScaleA scale_copy_a{}; + TiledCopyScaleB scale_copy_b{}; + + Tensor tSFAgSFA = tSFAgSFA_mkl(_, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor tSFBgSFB = tSFBgSFB_nkl(_, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor thr_tile_SFA_k
= tSFAIdentSFA_mkl(_0{}, _, _, get<0>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor thr_tile_pSFA = make_tensor(shape(filter_zeros(thr_tile_SFA_k(_,_,_0{}), tSFAgSFA(_0{},_,_,_0{}).stride()))); + Tensor thr_tile_SFB_k = tSFBIdentSFB_nkl(_0{}, _, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + + Tensor thr_tile_pSFB = make_tensor(shape(filter_zeros(thr_tile_SFB_k(_,_,_0{}), tSFBgSFB(_0{},_,_,_0{}).stride()))); + + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + // Issue the Mainloop loads + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // LOCK mainloop_pipe_producer_state for _writing_ + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + + load2transform_pipeline.producer_acquire(load2transform_pipe_producer_state); + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + + int write_stage = mainloop_pipe_producer_state.index(); + auto curr_mainloop_pipe_producer_state = mainloop_pipe_producer_state; + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(thr_tile_pSFA); ++i) { + Tensor thr_tile_SFA = filter_zeros(thr_tile_SFA_k(_,_,*k_tile_iter), tSFAgSFA(_0{},_,_,_0{}).stride()); + thr_tile_pSFA(i) = elem_less(thr_tile_SFA(i), shape(filter_zeros(layout_SFA))); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(thr_tile_pSFB); ++i) { + Tensor thr_tile_SFB = filter_zeros(thr_tile_SFB_k(_,_,*k_tile_iter), tSFBgSFB(_0{},_,_,_0{}).stride()); + thr_tile_pSFB(i) = elem_less(thr_tile_SFB(i), shape(filter_zeros(layout_SFB))); + } + + if (cute::elect_one_sync()) { + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_b_->with(*tma_barrier, mcast_mask_b), tBgB(_,*k_tile_iter), tBsB(_,write_stage)); + copy_if(scale_copy_a, thr_tile_pSFA, filter_zeros(tSFAgSFA(_,_,_,*k_tile_iter)), filter_zeros(tSFAsSFA(_,_,_,load2transform_pipe_producer_state.index()))); + copy_if(scale_copy_b, thr_tile_pSFB, filter_zeros(tSFBgSFB(_,_,_,*k_tile_iter)), filter_zeros(tSFBsSFB(_,_,_,load2transform_pipe_producer_state.index()))); + load2transform_pipeline.producer_commit(load2transform_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive_noinc); + } + + __syncwarp(); + + ++load2transform_pipe_producer_state; + --k_tile_count; + ++k_tile_iter; + } + + return cute::make_tuple(mainloop_pipe_producer_state, load2transform_pipe_producer_state, k_tile_iter); + } + + /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline mainloop_pipeline, + Load2TransformPipeline load2transform_pipeline, + MainloopPipelineState mainloop_pipe_producer_state, + Load2TransformPipelineState load2transform_pipe_producer_state) { + // Issue the epilogue waits + // This helps avoid early exit of ctas in Cluster + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then would just be acquired since the phase was + // still inverted from make_producer_start_state + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + load2transform_pipeline.producer_tail(load2transform_pipe_producer_state); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + 
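// Before the consumer-side functions, a minimal standalone reference sketch (illustrative
// assumptions only, compiled separately from this header, not part of the collective) of the
// decomposition that mma() and transform() implement together: each scale step along K is
// accumulated into a partial result with the accumulator zeroed at the start of the step, and the
// partial is then multiplied by that step's SFA * SFB and folded into a persistent accumulator.

#include <cstdio>
#include <vector>

int main() {
  constexpr int M = 8, N = 8, K = 16;       // toy problem
  constexpr int TileK = 4;                  // one scale value per K tile
  constexpr int ScaleGranularityM = 4;      // one SFA block covers 4 rows
  constexpr int ScaleGranularityN = 8;      // one SFB block covers all 8 columns
  constexpr int KTiles = K / TileK;

  std::vector<float> A(M * K, 0.5f), B(N * K, 0.25f), D(M * N, 0.0f);
  std::vector<float> SFA((M / ScaleGranularityM) * KTiles, 2.0f);
  std::vector<float> SFB((N / ScaleGranularityN) * KTiles, 4.0f);

  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float full_acc = 0.0f;
      for (int kt = 0; kt < KTiles; ++kt) {
        // Partial accumulation over one K tile (device: a group of MMAs, the first issued with ScaleOut::Zero).
        float part_acc = 0.0f;
        for (int k = kt * TileK; k < (kt + 1) * TileK; ++k) {
          part_acc += A[m * K + k] * B[n * K + k];
        }
        // Promotion: scale the partial by this tile's SFA * SFB and fold it into the full accumulator
        // (device: the transform warp reads the partial from tmem and does this in registers).
        float sfa = SFA[(m / ScaleGranularityM) * KTiles + kt];
        float sfb = SFB[(n / ScaleGranularityN) * KTiles + kt];
        full_acc += sfa * sfb * part_acc;
      }
      D[m * N + n] = full_acc;
    }
  }
  std::printf("D[0][0] = %f\n", D[0][0]);   // 0.5 * 0.25 * 16 * (2 * 4) = 16
  return 0;
}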
template < + class TmemStorage, + class MmaParams, + class CtaTileCoord + > + CUTLASS_DEVICE auto + mma( + cute::tuple pipelines, + cute::tuple pipeline_states, + TmemStorage tmem_storage, + MmaParams const& mma_inputs, + CtaTileCoord cta_tile_coord, + int k_tile_count) { + auto [tiled_mma, tCrA, tCrB] = mma_inputs; + + auto [mainloop_pipeline, + mma2transform_pipeline] = pipelines; + + auto [mainloop_pipe_consumer_state, + mma2transform_pipe_producer_state] = pipeline_states; + + uint32_t skip_wait = k_tile_count <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + // + // PIPELINED MAIN LOOP + // + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + // WAIT on mainloop_pipe_consumer_state until its data are available + // (phase bit flips from mainloop_pipe_consumer_state.phase() value) + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state); + + // Compute on k_tile + int read_stage = mainloop_pipe_consumer_state.index(); + // Save current mainloop pipeline read state + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + + // Advance mainloop_pipe + ++mainloop_pipe_consumer_state; + --k_tile_count; + skip_wait = k_tile_count <= 0; + // Peek at next iteration + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + + static_assert(size<2>(tCrA) % K_BLOCK_MMAS_PER_SCALE_K == 0, "k blocks must be divisible by K_BLOCK_MMAS_PER_SCALE_K"); + + CUTLASS_PRAGMA_UNROLL + for (int scale_k_blocks = 0; scale_k_blocks < size<2>(tCrA) / K_BLOCK_MMAS_PER_SCALE_K; ++scale_k_blocks) { + mma2transform_pipeline.producer_acquire(mma2transform_pipe_producer_state); + + auto acc = get<0>(slice_accumulator(tmem_storage, mma2transform_pipe_producer_state.index())); + static_assert(is_tmem>::value, "Accumulator must be tmem resident."); + static_assert(rank(remove_cvref_t{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); + + // for each set of scale_k_blocks we zero the accumulator + tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; + int start_k_block = scale_k_blocks * size<2>(tCrA) / K_BLOCK_MMAS_PER_SCALE_K; + // Unroll the K mode manually so we can set scale C to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block_offset = 0; k_block_offset < K_BLOCK_MMAS_PER_SCALE_K; ++k_block_offset) { + int k_block = start_k_block + k_block_offset; + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, + tCrA(_,_,k_block,read_stage), + tCrB(_,_,k_block,read_stage), + acc); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + mma2transform_pipeline.producer_commit(mma2transform_pipe_producer_state); + ++mma2transform_pipe_producer_state; + } + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + + } + + return make_tuple(mainloop_pipe_consumer_state, mma2transform_pipe_producer_state); + } + + /// Transform + template < + class TransformParams, + class TmemStorage, + class CtaTileCoord, + class CopyOpT2R, + class EpilogueTile + > + CUTLASS_DEVICE auto + transform( + cute::tuple pipelines, + cute::tuple consumer_states, + TmemStorage tmem_storage, + TransformParams const& transform_inputs, + CtaTileCoord cta_tile_coord, + CopyOpT2R, + EpilogueTile, + int k_tile_count) { + + static_assert(size<0>(EpilogueTile{}) <= size<0>(CtaShape_MNK{}), "Restrict epilogue tile to be smaller than or equal to CTA Tile"); + static_assert(size<1>(EpilogueTile{}) <= size<1>(CtaShape_MNK{}), "Restrict epilogue tile to be smaller than or equal
to CTA Tile"); + + + // + // PIPELINED Transform + // + + Tensor acc = get<0>(slice_accumulator(tmem_storage, _0{})); + + Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); + + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Append N with a stride of 0 to SFA + Tensor sSFA_ = transform_inputs.sSFA; + Tensor sSFA = make_tensor(sSFA_.data(), make_layout( + make_shape(get<0>(sSFA_.shape()), get<1>(CtaShape_MNK{}), get<1>(sSFA_.shape()), get<2>(sSFA_.shape())), + make_stride(get<0>(sSFA_.stride()), _0{}, get<1>(sSFA_.stride()), get<2>(sSFA_.stride())) + )); + + CUTE_STATIC_ASSERT_V(size<0>(sSFA) == size<0>(tAcc)); + CUTE_STATIC_ASSERT_V(size<1>(sSFA) == size<1>(tAcc)); + + Tensor sSFA_epi = flat_divide(sSFA, EpilogueTile{}); + + // Append M with a stride of 0 to SFB + Tensor sSFB_ = transform_inputs.sSFB; + Tensor sSFB = make_tensor(sSFB_.data(), make_layout( + make_shape(get<0>(CtaShape_MNK{}), get<0>(sSFB_.shape()), get<1>(sSFB_.shape()), get<2>(sSFB_.shape())), + make_stride(_0{}, get<0>(sSFB_.stride()), get<1>(sSFB_.stride()), get<2>(sSFB_.stride())) + )); + + CUTE_STATIC_ASSERT_V(size<0>(sSFB) == size<0>(tAcc)); + CUTE_STATIC_ASSERT_V(size<1>(sSFB) == size<1>(tAcc)); + + Tensor sSFB_epi = flat_divide(sSFB, EpilogueTile{}); + + TiledCopy tiled_t2r_epi = make_tmem_copy(CopyOpT2R{}, tAcc_epi(_,_,_0{},_0{})); + + int thread_idx = threadIdx.x % size(tiled_t2r_epi); + + ThrCopy thread_t2r_epi = tiled_t2r_epi.get_slice(thread_idx); + + Tensor acc_ident_epi = make_identity_tensor(shape(tAcc_epi)); + + Tensor tTR_rAcc_epi = thread_t2r_epi.partition_D(acc_ident_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + Tensor tTR_sSFA_epi = thread_t2r_epi.partition_D(sSFA_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + Tensor tTR_sSFB_epi = thread_t2r_epi.partition_D(sSFB_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + static_assert(rank(decltype(tTR_sSFA_epi){}) == 7); + + Tensor tTR_FullAcc = make_tensor(shape(tTR_rAcc_epi)); + Tensor tTR_PartAcc = make_tensor(shape(tTR_rAcc_epi(_,_,_,_0{},_0{}))); + + Tensor tTR_rSFA_compact = make_fragment_like(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,_0{}))); + Tensor tTR_rSFB_compact = make_fragment_like(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,_0{}))); + + Layout tTR_rSFA_layout = make_layout(tTR_sSFA_epi(_,_,_,_,_,_,_0{}).shape(), tTR_rSFA_compact.stride()); + Layout tTR_rSFB_layout = make_layout(tTR_sSFB_epi(_,_,_,_,_,_,_0{}).shape(), tTR_rSFB_compact.stride()); + + // Zero our accumulator + clear(tTR_FullAcc); + + auto [mma2transform_pipeline, load2transform_pipeline] = pipelines; + auto [mma2transform_pipe_state, load2transform_pipe_state] = consumer_states; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + load2transform_pipeline.consumer_wait(load2transform_pipe_state); + int read_idx = load2transform_pipe_state.index(); + + copy(filter_zeros(tTR_sSFA_epi(_,_,_,_,_,_,read_idx)), tTR_rSFA_compact); + copy(filter_zeros(tTR_sSFB_epi(_,_,_,_,_,_,read_idx)), tTR_rSFB_compact); + + CUTE_STATIC_ASSERT_V(cosize(tTR_rSFA_layout) == size(tTR_rSFA_compact)); + CUTE_STATIC_ASSERT_V(cosize(tTR_rSFB_layout) == size(tTR_rSFB_compact)); + + Tensor tTR_rSFA = make_tensor(tTR_rSFA_compact.data(), tTR_rSFA_layout); + Tensor tTR_rSFB = make_tensor(tTR_rSFB_compact.data(), tTR_rSFB_layout); + + load2transform_pipeline.consumer_release(load2transform_pipe_state); + ++load2transform_pipe_state; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < ScaleKsPerTile; ++k_block) { + + 
mma2transform_pipeline.consumer_wait(mma2transform_pipe_state); + + Tensor acc = get<0>(slice_accumulator(tmem_storage, mma2transform_pipe_state.index())); + Tensor tAcc = acc(make_coord(_,_),_0{},_0{}); + Tensor tAcc_epi = flat_divide(tAcc, EpilogueTile{}); // (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + Tensor tTR_tAcc = thread_t2r_epi.partition_S(tAcc_epi); // (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(tAcc_epi); ++epi_m) { + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(tAcc_epi); ++epi_n) { + + auto scale_a = tTR_rSFA(_,_,_,epi_m,epi_n,k_block * ScaleGranularityK); + auto scale_b = tTR_rSFB(_,_,_,epi_m,epi_n,k_block * ScaleGranularityK); + + Tensor full_acc = tTR_FullAcc(_,_,_,epi_m,epi_n); + // Compute tmem load predication if necessary + copy(tiled_t2r_epi, tTR_tAcc(_,_,_,epi_m,epi_n), tTR_PartAcc); + cutlass::arch::fence_view_async_tmem_load(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(full_acc); ++i) { + ElementAccumulator scale = scale_a(i) * scale_b(i); + full_acc(i) += scale * tTR_PartAcc(i); + } + } + } + cutlass::arch::fence_view_async_tmem_load(); + mma2transform_pipeline.consumer_release(mma2transform_pipe_state); + // release acc + ++mma2transform_pipe_state; + } + + --k_tile_count; + } + + return cute::make_tuple(tTR_FullAcc, tiled_t2r_epi, cute::make_tuple(mma2transform_pipe_state, load2transform_pipe_state)); + } + +protected: + + typename Params::TMA_A const* observed_tma_load_a_{nullptr}; + typename Params::TMA_B const* observed_tma_load_b_{nullptr}; + + RuntimeDataTypeA runtime_data_type_a_{}; + RuntimeDataTypeB runtime_data_type_b_{}; + + ClusterShape cluster_shape_; + uint32_t block_rank_in_cluster_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp index 22a7d4bed..15384bbc3 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -204,6 +204,8 @@ public: using PipelineState = cutlass::PipelineState; using PipelineParams = typename MainloopPipeline::Params; + static constexpr int NumProducerThreadEvents = 1; + using SmemLayoutAtomScale = Layout(SwappedSmemLayoutAtomA{})), cute::Int<1>>>; using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); @@ -1354,6 +1356,18 @@ public: static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in tensormaps_fence_acquire."); } } + + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + [[maybe_unused]] Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + [[maybe_unused]] int32_t next_batch) { + return input_tensors; + } + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp index 789a3cb15..61774102a 100644 --- 
a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp @@ -116,6 +116,9 @@ struct CollectiveMma< using PipelineParams = typename MainloopPipeline::Params; using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + static constexpr int NumProducerThreadEvents = 1; + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); @@ -749,6 +752,16 @@ struct CollectiveMma< cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); } + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + [[maybe_unused]] Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + [[maybe_unused]] int32_t next_batch) { + return input_tensors; + } }; diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp index b3d857ff3..f6aae992b 100644 --- a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8.hpp @@ -759,6 +759,18 @@ struct CollectiveMma< cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); } + + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + InputTensors const& input_tensors, + [[maybe_unused]] Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + [[maybe_unused]] int32_t next_batch) { + return input_tensors; + } + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp new file mode 100644 index 000000000..a55059ad8 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -0,0 +1,1036 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm80.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + int ScaleGranularityM_, + int ScaleGranularityN_, + int ScalePromotionInterval_, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using ElementBlockScale = ElementAccumulator; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + using 
CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + static constexpr int NumProducerThreadEvents = 2; + + static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; + static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? size<1>(TileShape{}) : ScaleGranularityN_; + static constexpr int ScalePromotionInterval = ScalePromotionInterval_; + static_assert(ScalePromotionInterval % 4 == 0, "ScalePromotionInterval must be a multiple of 4."); + + static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M."); + static_assert((size<1>(TileShape{}) % ScaleGranularityN) == 0, "FP8 scaling granularity must evenly divide tile shape along N."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // Block scaling gmem-to-smem copy atom + using BlockScaleCopyTypeA = cute::uint_byte_t(sizeof(ElementBlockScale)) * ScaleMsPerTile, 16)>; + using BlockScaleCopyTypeB = cute::uint_byte_t(sizeof(ElementBlockScale)) * ScaleNsPerTile, 16)>; + using SmemBlockScalingCopyAtomA = Copy_Atom, ElementBlockScale>; + using SmemBlockScalingCopyAtomB = Copy_Atom, ElementBlockScale>; + + // Block scaling smem layout + using SmemLayoutScaleA = Layout, Int>>; + using SmemLayoutScaleB = Layout, Int>>; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
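// Returning to the granularity parameters resolved above: a granularity of 0 is shorthand for
// "one scale factor per full tile extent", and ScaleMsPerTile / ScaleNsPerTile count how many
// scale values land inside one CTA tile. A small standalone illustration follows; the 128x128 tile
// and the resolve() helper are assumptions for this sketch, not names from this header.

constexpr int TileM = 128, TileN = 128;

constexpr int resolve(int granularity, int tile_extent) {
  return granularity == 0 ? tile_extent : granularity;
}

// Blockwise scaling: one scale per 128x128 tile of A and of B.
static_assert(TileM / resolve(0, TileM) == 1, "ScaleMsPerTile == 1");
static_assert(TileN / resolve(0, TileN) == 1, "ScaleNsPerTile == 1");

// Groupwise scaling along M: one scale per row of A within the tile, one per tile of B.
static_assert(TileM / resolve(1, TileM) == 128, "ScaleMsPerTile == 128");
static_assert(TileN / resolve(0, TileN) == 1, "ScaleNsPerTile == 1");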
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + static_assert(cute::is_same_v, + "ElementAccumulator and ElementBlockScale should be same datatype"); + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_scale_A; + cute::array_aligned> smem_scale_B; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_A; + cute::TmaDescriptor smem_tensormap_B; + } tensormaps; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + ElementBlockScale const** ptr_scale_A; + ElementBlockScale const** ptr_scale_B; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + void* tensormaps; + InternalElementA const** ptr_A; + StrideA dA; + InternalElementB const** ptr_B; + StrideB dB; + // Block scaling factors for A and B + ElementBlockScale const** ptr_scale_A; + ElementBlockScale const** ptr_scale_B; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape problem_shapes, + Arguments const& args, + void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_K = get<2>(init_shape); + // Batches/Groups are managed by using appropriate pointers to input matrices + const uint32_t mock_L = 1; + InternalElementA const* ptr_A_first_batch = reinterpret_cast(args.ptr_A); + InternalElementB const* ptr_B_first_batch = reinterpret_cast(args.ptr_B); + + InternalStrideA stride_a; + InternalStrideB stride_b; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_a = InternalStrideA{}; + stride_b = InternalStrideB{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
+ auto problem_shape_MNK = problem_shapes.get_host_problem_shape(0); + init_M = get<0>(problem_shape_MNK); + init_N = get<1>(problem_shape_MNK); + init_K = get<2>(problem_shape_MNK); + + stride_a = args.dA; + stride_b = args.dB; + } + Tensor tensor_a = make_tensor(ptr_A_first_batch, make_layout(make_shape(init_M,init_K,mock_L), stride_a)); + Tensor tensor_b = make_tensor(ptr_B_first_batch, make_layout(make_shape(init_N,init_K,mock_L), stride_b)); + auto tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + auto tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + void* tensormaps = workspace; + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes, + tensormaps, + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB, + args.ptr_scale_A, + args.ptr_scale_B + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + constexpr uint32_t NumInputTensors = 2; + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (NumInputTensors * SizeOfCuTensorMap * sm_count); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + bool implementable = true; + constexpr int tma_alignment_bits = 128; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + + if (problem_shapes.is_host_problem_shape_available()) { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + // We expect full tiles in K + implementable = implementable && K % size<2>(TileShape{}) == 0; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; + static constexpr uint32_t TmaTransactionBytes = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value))+ + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + // Set up the data needed by this collective for load and mma. + // Returns a tuple of tensors. 
The collective and the kernel layer have the contract that the + // returned tuple must contain at least two elements, with the first two elements being: + // gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + // gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + // The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init( + ProblemShape_MNKL const& problem_shape_MNKL, + Params const& mainloop_params, + ElementBlockScale const* ptr_scale_A = nullptr, + ElementBlockScale const* ptr_scale_B = nullptr + ) const { + + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + const int32_t mock_L = 1; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,mock_L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,mock_L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + auto tK = get<3>(gA_mkl.shape()); + + // Make the tiled views of scale tensors + auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l) + auto scaleB_shape = make_shape(N / ScaleGranularityN, tK, L); // (scale_n,k,l) + auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{}); + auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_0, _1, _2>{}); + + // Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and + // gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl. 
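// The layouts built above fix the extents a caller has to allocate for the scale-factor tensors:
// one value per (granularity x K-tile) block, with the M (or N) block mode fastest, then the
// K-tile mode, then the batch mode L. A host-side sketch of those extents; the helper names and
// the example problem size are assumptions for illustration, not part of this collective.

#include <cstdio>

struct ScaleExtents { int mn_blocks; int k_tiles; int batches; };

ScaleExtents scale_extents_A(int M, int K, int L, int granularity_m, int tile_k) {
  return { M / granularity_m, (K + tile_k - 1) / tile_k, L };
}
ScaleExtents scale_extents_B(int N, int K, int L, int granularity_n, int tile_k) {
  return { N / granularity_n, (K + tile_k - 1) / tile_k, L };
}

int main() {
  // e.g. M = N = K = 4096, one batch, TileK = 128, 1x128 groupwise scaling on A, 128x128 blockwise on B
  ScaleExtents sfa = scale_extents_A(4096, 4096, 1, /*granularity_m=*/1, /*tile_k=*/128);
  ScaleExtents sfb = scale_extents_B(4096, 4096, 1, /*granularity_n=*/128, /*tile_k=*/128);
  std::printf("SFA extents: %d x %d x %d\n", sfa.mn_blocks, sfa.k_tiles, sfa.batches);  // 4096 x 32 x 1
  std::printf("SFB extents: %d x %d x %d\n", sfb.mn_blocks, sfb.k_tiles, sfb.batches);  // 32 x 32 x 1
  return 0;
}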
+ + Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(ptr_scale_A), scaleA_layout); // (scale_m,k,l) + Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(ptr_scale_B), scaleB_layout); // (scale_n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); + + } + + // Perform a collective-scoped matrix multiply-accumulate + // Producer Perspective + template < + class TensorA, class TensorB, + class TensorMapA, class TensorMapB, + class TensorScaleA, class TensorScaleB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& load_inputs, + cute::tuple const& input_tensormaps, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + // Blockscaling: Tma loads for load_input and CpAsync for load_scale + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k) + Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (ScaleNsPerTile,k) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
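// As the partitioning below shows, the scale tensors are sliced per CTA: the block at coordinate
// (m_coord, n_coord, l_coord) consumes a contiguous run of ScaleMsPerTile values of SFA (and
// ScaleNsPerTile values of SFB) for each K tile it loads. An index-arithmetic sketch of that slice
// follows, with made-up example numbers that are not taken from this kernel.

#include <cstdio>

int main() {
  constexpr int ScaleMsPerTile = 128;   // e.g. TileM = 128 with ScaleGranularityM = 1
  constexpr int scale_m_total  = 4096;  // M / ScaleGranularityM for the whole problem
  int m_coord = 3, k_tile = 7;          // hypothetical CTA coordinate and K-tile index

  // SFA is (scale_m, k_tiles, l) with scale_m fastest, so this CTA's slice for this K tile is:
  int first = k_tile * scale_m_total + m_coord * ScaleMsPerTile;
  std::printf("SFA linear indices [%d, %d)\n", first, first + ScaleMsPerTile);  // [29056, 29184)
  return 0;
}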
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Block scaling: load_scale has scaling tensors in global memory which are not tiled + Tensor mScaleA_mkl = get<2>(load_inputs); + Tensor mScaleB_nkl = get<3>(load_inputs); + + auto scales_m = get<0>(mScaleA_mkl.shape()); + auto scales_n = get<0>(mScaleB_nkl.shape()); + + Tensor gScaleA = local_tile(mScaleA_mkl, make_tile(Int{}), make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1) + Tensor gScaleB = local_tile(mScaleB_nkl, make_tile(Int{}), make_coord(n_coord,_,l_coord)); // (ScaleNsPerTile,k,1) + + TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{}, Layout>{}, Layout>>{}); + TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{}, Layout>{}, Layout>>{}); + + ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x); + ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x); + + Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA); + Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA); + + Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB); + Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(get<0>(input_tensormaps), *tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(get<1>(input_tensormaps), *tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + + // Copy scale tensors from global memory to shared memory + copy(scale_copy_a, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage)); + copy(scale_copy_b, tBgB_ScaleB(_,_,*k_tile_iter), tBsB_ScaleB(_,_,write_stage)); + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc); + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + // Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue 
waits + if (lane_predicate) { + // This helps avoid early exit of blocks in Cluster. + // Waits for all stages to either be released (all + // Consumer UNLOCKs), or if the stage was never used + // then it would just be acquired since the phase was + // still inverted from make_producer_start_state. + pipeline.producer_tail(smem_pipe_write); + } + } + + + template< + class EngineAccum, + class LayoutAccum, + class ScaleFactor + > + CUTLASS_DEVICE + void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor scaleFactor) { + if constexpr (ScalePromotionInterval != 4) { + accumulation.scale_if_needed(scaleFactor); + } + else { + // avoid unnecessary tests when granularity is the finest + accumulation.scale(scaleFactor); + } + } + + template< + class EngineAccum, + class LayoutAccum, + class ScaleFactor1, + class ScaleFactor2 + > + CUTLASS_DEVICE + void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) { + if constexpr (ScalePromotionInterval != 4) { + accumulation.scale_if_needed(scaleFactor1, scaleFactor2); + } + else { + // avoid unnecessary tests when granularity is the finest + accumulation.scale(scaleFactor1, scaleFactor2); + } + } + + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Block scaling + Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), + Layout< + Shape, Int>, cute::tuple_element_t<1, TileShape>, Int>, + Stride, _0, Int> + >{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k) + Tensor sScaleBViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), + Layout< + Shape, Shape, Int>, Int>, + Stride<_0, Stride<_0, _1>, Int> + >{}); // (m,(ScaleGranularityN,ScaleNsPerTile),k) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and + stride<0>(typename TiledMma::BLayout{}) == 0 and + size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and + size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, + Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + Tensor tCsScaleAViewAsC =
tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. + Tensor tCsScaleBViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleBViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C. + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Per block scale values for operand A and B + Tensor tCrScaleAViewAsC = make_tensor_like(tCsScaleAViewAsC(_, _, _, 0)); // (MMA,MMA_M,MMA_N) + Tensor tCrScaleBViewAsC = make_tensor_like(tCsScaleBViewAsC(_, _, _, 0)); // (MMA,MMA_M,MMA_N) + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + assert(k_tile_count >= 1); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // fence_operand(); + GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + // Load per block scale values from shared memory to registers + copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC); + copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0]; + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrScaleBViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleAViewAsC); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleBViewAsC); i++) { + tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a; + } + } + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + 
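// The interplay between ScalePromotionInterval, prepare_if_needed(), and scale_if_needed() in this
// loop boils down to: the GMMA accumulator runs as a short-lived partial, and every
// ScalePromotionInterval MMA iterations that partial is scaled by the current block scales and
// promoted into the persistent accumulator, with a residue pass at the end. A scalar standalone
// sketch of the mechanism follows; the interval, tile counts, and constant stand-in scale are all
// assumptions for illustration.

#include <cstdio>

int main() {
  constexpr int ScalePromotionInterval = 8;  // hypothetical: promote every two k tiles of 4 MMAs each
  constexpr int mmas_per_tile = 4;
  constexpr int k_tiles = 6;

  float accum = 0.0f;    // persistent accumulator
  float partial = 0.0f;  // the GMMA accumulator between promotions
  int mma_count = 0;

  for (int kt = 0; kt < k_tiles; ++kt) {
    float scale = 2.0f;                            // stand-in for SFA * SFB of this tile
    for (int i = 0; i < mmas_per_tile; ++i) {
      partial += 1.0f;                             // stand-in for one MMA's contribution
      ++mma_count;
    }
    if (mma_count % ScalePromotionInterval == 0) { // scale_if_needed()
      accum += scale * partial;                    // promote, then restart the partial
      partial = 0.0f;
    }
  }
  if (partial != 0.0f) {                           // scale_residue_if_needed()
    accum += 2.0f * partial;
  }
  std::printf("accum = %f\n", accum);              // 6 tiles * 4 MMAs * 2.0 = 48
  return 0;
}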
CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; + scale_if_needed(accumulation, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + scale_if_needed(accumulation, tCrScaleAViewAsC); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrScaleBViewAsC); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrScaleAViewAsC, tCrScaleBViewAsC); + } + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + // fence_operand(); + // Load per block scale values from shared memory to registers (at most twice per block along M and/or N) + copy(tCsScaleAViewAsC(_, _, _, read_stage), tCrScaleAViewAsC); + copy(tCsScaleBViewAsC(_, _, _, read_stage), tCrScaleBViewAsC); + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + tCrScaleAViewAsC.data()[0] = tCrScaleAViewAsC.data()[0] * tCrScaleBViewAsC.data()[0]; + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_b = tCrScaleBViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleAViewAsC); i++) { + tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b; + } + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + ElementBlockScale scale_a = tCrScaleAViewAsC.data()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrScaleBViewAsC); i++) { + tCrScaleBViewAsC.data()[i] = tCrScaleBViewAsC.data()[i] * scale_a; + } + } + + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; + scale_if_needed(accumulation, scale_ab); + } + if constexpr (ScaleMsPerTile > 1 
&& ScaleNsPerTile == 1) { + scale_if_needed(accumulation, tCrScaleAViewAsC); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrScaleBViewAsC); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + scale_if_needed(accumulation, tCrScaleAViewAsC, tCrScaleBViewAsC); + } + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + if constexpr (ScalePromotionInterval != 4) { + // residues only exist when granularity is not the finest + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; + accumulation.scale_residue_if_needed(scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_residue_if_needed(tCrScaleAViewAsC); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrScaleBViewAsC); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC); + } + } + + warpgroup_fence_operand(accumulation()); + + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + CUTLASS_DEVICE auto + tensormaps_init( + Params const& mainloop_params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + cute::TmaDescriptor* gmem_tensormap = reinterpret_cast(mainloop_params.tensormaps); + + cute::TmaDescriptor* tma_desc_a = &gmem_tensormap[sm_idx]; + cute::TmaDescriptor* tma_desc_b = &gmem_tensormap[sm_idx + sm_count]; + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + Tensor pA_tensormap = make_tensor(mainloop_params.tma_load_a.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sA_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_A), Int<1>{}, Int<1>{}); + Tensor pB_tensormap = make_tensor(mainloop_params.tma_load_b.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sB_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_B), Int<1>{}, Int<1>{}); + + copy(recast(pA_tensormap), recast(sA_tensormap)); + copy(recast(pB_tensormap), recast(sB_tensormap)); + } + __syncwarp(); + + return cute::make_tuple(tma_desc_a, tma_desc_b); + } + + // Replace address for the global tensor (to be done by single thread) + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_batch) { + // Replacing global_address for the next batch + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_A, + mainloop_params.ptr_A[next_batch]); + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_B, + mainloop_params.ptr_B[next_batch]); + } +
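// For grouped use of this collective, the Arguments above take per-group pointer arrays for A, B,
// and both scale tensors; per group, the kernel rewrites the TMA descriptor's global address (and,
// for grouped GEMM, its dims and strides) through the tensormaps_* methods before that group's
// first load. A host-side sketch of staging such per-group pointers follows; the names and the
// staging flow are assumptions for illustration, not the CUTLASS host API.

#include <vector>

struct GroupPointers {
  std::vector<void const*>  ptr_A;        // ptr_A[g]       -> FP8 A operand of group g
  std::vector<void const*>  ptr_B;        // ptr_B[g]       -> FP8 B operand of group g
  std::vector<float const*> ptr_scale_A;  // ptr_scale_A[g] -> blockwise scales of A for group g
  std::vector<float const*> ptr_scale_B;  // ptr_scale_B[g] -> blockwise scales of B for group g
};

GroupPointers stage_groups(std::vector<void const*> const& A,
                           std::vector<void const*> const& B,
                           std::vector<float const*> const& SFA,
                           std::vector<float const*> const& SFB) {
  // One entry per group. These arrays would be copied to device memory and their device addresses
  // passed through the pointer-to-pointer members of Arguments; the per-group scale pointers are
  // what tensors_perform_update() swaps in when the scheduler moves to the next group.
  return GroupPointers{A, B, SFA, SFB};
}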
+ // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + const uint32_t K = get<2>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape_A = {1,1,1,1,1}; + cute::array prob_stride_A = {0,0,0,0,0}; + cute::array prob_shape_B = {1,1,1,1,1}; + cute::array prob_stride_B = {0,0,0,0,0}; + + InternalElementA const* ptr_A = nullptr; + Tensor tensor_a = make_tensor(ptr_A, make_shape(M,K,Int<1>{}), mainloop_params.dA[next_group]); + + InternalElementB const* ptr_B = nullptr; + Tensor tensor_b = make_tensor(ptr_B, make_shape(N,K,Int<1>{}), mainloop_params.dB[next_group]); + + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_a, tensor_a, + prob_shape_A, prob_stride_A); + cute::detail::fill_tma_gmem_shape_stride(mainloop_params.tma_load_b, tensor_b, + prob_shape_B, prob_stride_B); + + // Convert strides to byte strides + for (uint64_t& stride : prob_stride_A) { + stride = (stride * sizeof_bits_v) / 8; + } + for (uint64_t& stride : prob_stride_B) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_A, + prob_shape_A, + prob_stride_A); + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_B, + prob_shape_B, + prob_stride_B); + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& mainloop_params, + cute::tuple const& input_tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, mainloop_params, next_batch); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties(shared_tensormaps, + mainloop_params, next_batch, problem_shape_mnkl); + } + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release ( + TensorMapStorage& shared_tensormaps, + cute::tuple const& input_tensormaps) { + // Entire warp must do this (i.e. 
it's aligned) + tma_descriptor_cp_fence_release(get<0>(input_tensormaps), shared_tensormaps.smem_tensormap_A); + tma_descriptor_cp_fence_release(get<1>(input_tensormaps), shared_tensormaps.smem_tensormap_B); + } + + // The entire warp must call this function collectively (that is, the instructions are aligned) + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::tuple const& input_tensormaps) { + cute::tma_descriptor_fence_acquire(get<0>(input_tensormaps)); + cute::tma_descriptor_fence_acquire(get<1>(input_tensormaps)); + } + + template + CUTLASS_DEVICE + InputTensors + tensors_perform_update( + [[maybe_unused]] InputTensors const& input_tensors, + Params const& mainloop_params, + [[maybe_unused]] ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch) { + + if constexpr (IsGroupedGemmKernel) { + return load_init( + problem_shape_mnkl, + mainloop_params, + mainloop_params.ptr_scale_A[next_batch], + mainloop_params.ptr_scale_B[next_batch] + ); + } else { + auto [gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl] = input_tensors; + + auto scaleA_layout = mScaleA_mkl.layout(); + auto scaleB_layout = mScaleB_nkl.layout(); + + mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A[next_batch]), scaleA_layout); // (m,ScaleMsPerTile,k,l) + mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B[next_batch]), scaleB_layout); // (n,ScaleNsPerTile,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index 546bf9159..fa00c27e7 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -59,6 +59,7 @@ template < class KernelSchedule, int ScaleGranularityM_, int ScaleGranularityN_, + int ScalePromotionInterval_, class TileShape_, class ElementA_, class StrideA_, @@ -74,7 +75,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, + MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8, TileShape_, ElementA_, StrideA_, @@ -93,7 +94,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; @@ -122,6 +123,8 @@ struct CollectiveMma< static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_; static constexpr int ScaleGranularityN = ScaleGranularityN_ == 0 ? 
size<1>(TileShape{}) : ScaleGranularityN_; + static constexpr int ScalePromotionInterval = ScalePromotionInterval_; + static_assert(ScalePromotionInterval % 4 == 0, "ScalePromotionInterval must be a multiple of 4."); static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM; static constexpr int ScaleNsPerTile = size<1>(TileShape{}) / ScaleGranularityN; @@ -281,7 +284,9 @@ struct CollectiveMma< constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */ - implementable = implementable && (args.mma_promotion_interval % 4 == 0); + constexpr int pipe_k = size<2>(TileShape{}) / tile_size<2>(TiledMma{}); + implementable = implementable && (args.mma_promotion_interval % 4 == 0) && (args.mma_promotion_interval == ScalePromotionInterval); + implementable = implementable && (pipe_k % 4 == 0) && (pipe_k <= args.mma_promotion_interval); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); @@ -481,6 +486,38 @@ struct CollectiveMma< } } + template< + class EngineAccum, + class LayoutAccum, + class ScaleFactor + > + CUTLASS_DEVICE + void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor scaleFactor) { + if constexpr (ScalePromotionInterval != 4) { + accumulation.scale_if_needed(scaleFactor); + } + else { + // avoid unnecessary tests when granularity is the finest + accumulation.scale(scaleFactor); + } + } + template< + class EngineAccum, + class LayoutAccum, + class ScaleFactor1, + class ScaleFactor2 + > + CUTLASS_DEVICE + void scale_if_needed(GmmaFP8Accumulation& accumulation, ScaleFactor1 scaleFactor1, ScaleFactor2 scaleFactor2) { + if constexpr (ScalePromotionInterval != 4) { + accumulation.scale_if_needed(scaleFactor1, scaleFactor2); + } + else { + // avoid unnecessary tests when granularity is the finest + accumulation.scale(scaleFactor1, scaleFactor2); + } + } + /// Perform a collective-scoped matrix multiply-accumulate /// Consumer Perspective template < @@ -575,7 +612,7 @@ struct CollectiveMma< tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + GmmaFP8Accumulation accumulation(accum, ScalePromotionInterval, size<2>(tCrA)); warpgroup_fence_operand(accumulation()); CUTLASS_PRAGMA_UNROLL for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) @@ -584,7 +621,13 @@ struct CollectiveMma< auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); - if (accumulation.prepare_if_needed()) { + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } @@ -624,16 +667,16 @@ struct CollectiveMma< // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; - accumulation.scale_if_needed(scale_ab); + scale_if_needed(accumulation, scale_ab); } if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { -
accumulation.scale_if_needed(tCrScaleAViewAsC); + scale_if_needed(accumulation, tCrScaleAViewAsC); } if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - accumulation.scale_if_needed(tCrScaleBViewAsC); + scale_if_needed(accumulation, tCrScaleBViewAsC); } if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC); + scale_if_needed(accumulation, tCrScaleAViewAsC, tCrScaleBViewAsC); } ++smem_pipe_read; @@ -677,7 +720,13 @@ struct CollectiveMma< } } - if (accumulation.prepare_if_needed()) { + if constexpr (ScalePromotionInterval != 4) { + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + } + else { + // Always zero out the accumulator for finest granularity tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } @@ -699,16 +748,16 @@ struct CollectiveMma< // Block scale the accumulators with reg tensor `tCrScaleAViewAsC` and `tCrScaleBViewAsC` if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; - accumulation.scale_if_needed(scale_ab); + scale_if_needed(accumulation, scale_ab); } if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - accumulation.scale_if_needed(tCrScaleAViewAsC); + scale_if_needed(accumulation, tCrScaleAViewAsC); } if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - accumulation.scale_if_needed(tCrScaleBViewAsC); + scale_if_needed(accumulation, tCrScaleBViewAsC); } if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - accumulation.scale_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC); + scale_if_needed(accumulation, tCrScaleAViewAsC, tCrScaleBViewAsC); } pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it @@ -718,18 +767,21 @@ struct CollectiveMma< ++smem_pipe_release; } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { - ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; - accumulation.scale_residue_if_needed(scale_ab); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { - accumulation.scale_residue_if_needed(tCrScaleAViewAsC); - } - if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { - accumulation.scale_residue_if_needed(tCrScaleBViewAsC); - } - if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { - accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC); + if constexpr (ScalePromotionInterval != 4) { + // residues only exist when granularity is not the finest + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile == 1) { + ElementBlockScale scale_ab = tCrScaleAViewAsC.data()[0]; + accumulation.scale_residue_if_needed(scale_ab); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile == 1) { + accumulation.scale_residue_if_needed(tCrScaleAViewAsC); + } + if constexpr (ScaleMsPerTile == 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrScaleBViewAsC); + } + if constexpr (ScaleMsPerTile > 1 && ScaleNsPerTile > 1) { + accumulation.scale_residue_if_needed(tCrScaleAViewAsC, tCrScaleBViewAsC); + } } warpgroup_fence_operand(accumulation()); diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 0936eb251..e1ce0def6 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -120,10 +120,38 @@ template< // `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value // `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling 
granularity is // `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N. - int ScaleGranularityM = 0, - int ScaleGranularityN = 0 + int ScaleGranularityM_ = 0, + int ScaleGranularityN_ = 0, + // `ScalePromotionInterval` specifies the interval to promote the accumulator for scaling + // It is required to be a multiple of 4 and specified in terms of number of MMA instructions + // in the reduction dimension. i.e for FP8 kernels, it is + // ScalePromotionInterval * MMA_K = ScalePromotionInterval * 32 = 128 elements in K by default + int ScalePromotionInterval_ = 4 + > -struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { }; +struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelTmaWarpSpecializedCooperative { + constexpr static int ScaleGranularityM = ScaleGranularityM_; + constexpr static int ScaleGranularityN = ScaleGranularityN_; + constexpr static int ScalePromotionInterval = ScalePromotionInterval_; +}; + +template< + // `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value + // `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is + // `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N. + int ScaleGranularityM_, + int ScaleGranularityN_, + // `ScalePromotionInterval` specifies the interval to promote the accumulator for scaling + // It is required to be a multiple of 4 and specified in terms of number of MMA instructions + // in the reduction dimension. i.e for FP8 kernels, it is + // ScalePromotionInterval * MMA_K = ScalePromotionInterval * 32 = 128 elements in K by default + int ScalePromotionInterval_ = 4 +> +struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum: KernelPtrArrayTmaWarpSpecializedCooperative { + constexpr static int ScaleGranularityM = ScaleGranularityM_; + constexpr static int ScaleGranularityN = ScaleGranularityN_; + constexpr static int ScalePromotionInterval = ScalePromotionInterval_; +}; // Policies to opt into mixed type GEMMs struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; @@ -310,12 +338,17 @@ template< // `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is // `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N. int ScaleGranularityM = 0, - int ScaleGranularityN = 0 + int ScaleGranularityN = 0, + // `ScalePromotionInterval` specifies the interval to promote the accumulator for scaling + // It is required to be a multiple of 4 and specified in terms of number of MMA instructions + // in the reduction dimension. 
i.e for FP8 kernels, it is + // ScalePromotionInterval * MMA_K = ScalePromotionInterval * 32 = 128 elements in K by default + int ScalePromotionInterval = 4 > struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8 : MainloopSm90TmaGmmaWarpSpecialized { static_assert( - cute::is_same_v>, + cute::is_same_v>, "KernelSchedule must be one of the warp specialized policies"); }; @@ -327,6 +360,7 @@ template< > struct MainloopSm90ArrayTmaGmmaWarpSpecialized { constexpr static int Stages = Stages_; + constexpr static int PipelineAsyncMmaStages = 1; using ClusterShape = ClusterShape_; using ArchTag = arch::Sm90; using Schedule = KernelSchedule; @@ -391,6 +425,26 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecializedMixedInput { "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies"); }; +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule +// For FP8 kernels with Block Scaling +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelPtrArrayTmaWarpSpecializedCooperative, + // `ScaleGranularityM`/`ScaleGranularityN` specifies scaling granularity along M/N, while zero-value + // `ScaleGranularityM`/`ScaleGranularityN` indicates that scaling granularity is + // `size<0>(TileShape_MNK{})`/`size<1>(TileShape_MNK{})` along M/N. + int ScaleGranularityM = 0, + int ScaleGranularityN = 0, + int ScalePromotionInterval = 4 +> +struct MainloopSm90ArrayTmaGmmaWarpSpecializedBlockScaling + : MainloopSm90ArrayTmaGmmaWarpSpecialized { + static_assert( + cute::is_same_v>, + "KernelSchedule must be one of the warp specialized policies"); +}; + template< int SchedulerPipelineStageCount_, @@ -411,6 +465,14 @@ struct KernelTmaWarpSpecializedBlockScaledSm100 final { static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; }; +template< + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_ +> +struct KernelTmaWarpSpecializedMmaTransformSm100 final { + static constexpr int SchedulerPipelineStageCount = SchedulerPipelineStageCount_; + static constexpr int AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_; +}; // InputTransform GEMM @@ -484,6 +546,13 @@ struct KernelScheduleSm100PtrArrayDenseGemm : KernelScheduleSm100DenseGemm {}; struct KernelPtrArrayTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100PtrArrayDenseGemm {}; struct KernelPtrArrayTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100PtrArrayDenseGemm {}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// SM100 Blockwise GEMM Dispatch Policies +/////////////////////////////////////////////////////////////////////////////////////////////////////// +struct KernelScheduleSm100Blockwise : KernelScheduleSm100 {}; +struct KernelTmaWarpSpecializedBlockwise1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100Blockwise {}; + /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Planar Complex GEMM Dispatch Policies /////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -530,6 +599,9 @@ struct KernelTmaWarpSpecialized1SmMxf4Sm100 final : KernelSchedule1Sm, KernelSch struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, KernelScheduleMxNvf4Sm100 { }; struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 
{ }; struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { }; +/////////////////////////////////////////////////////////////////////////////////////////////////////// +// SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies +/////////////////////////////////////////////////////////////////////////////////////////////////////// // BlockScaled Dense GEMM + (Ptr Array or Group GEMM) struct KernelSchedulePtrArrayBlockScaledGemmSm100 : KernelScheduleBlockScaledGemmSm100 {}; struct KernelSchedulePtrArrayMxNvf4Sm100 : KernelSchedulePtrArrayBlockScaledGemmSm100 {}; @@ -544,8 +616,6 @@ struct KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2Sm, K struct KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { }; struct KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelSchedulePtrArrayMxf8f6f4Sm100 { }; - - // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< int Stages_, @@ -561,7 +631,20 @@ struct MainloopSm100TmaUmmaWarpSpecialized { constexpr static bool IsOverlappingAccum = false; }; - +// n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule +template< + int Stages_, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ClusterShape_ = Shape<_1,_1,_1> +> +struct MainloopSm100TmaUmmaWarpSpecializedBlockwiseScaling { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm100; + using Schedule = KernelTmaWarpSpecializedMmaTransformSm100; + constexpr static bool IsOverlappingAccum = false; +}; // n-buffer in smem, pipelined with Blackwell UMMA and TMA, Warp specialized dynamic schedule template< diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 77b2e1ead..a6c2f43cd 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -65,6 +65,7 @@ struct IsCutlass3ArrayKernel +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileSchedulerTag_, + cute::enable_if_t< + cutlass::detail::is_kernel_tag_of_v>> { +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 100); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueTile = typename CollectiveEpilogue::EpilogueTile; + using ElementC = typename 
CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static constexpr bool IsComplex = CollectiveEpilogue::NumAccumulatorMtxs == 2; + + // CLC pipeline depth + // determines how many waves (stages-1) a warp can race ahead + static constexpr uint32_t SchedulerPipelineStageCount = DispatchPolicy::Schedule::SchedulerPipelineStageCount; + static constexpr uint32_t AccumulatorPipelineStageCount = DispatchPolicy::Schedule::AccumulatorPipelineStageCount; + static constexpr bool IsOverlappingAccum = DispatchPolicy::IsOverlappingAccum; + + static_assert(!IsOverlappingAccum, "Does not support overlapping accumulator"); + + // TileID scheduler + // Get Blk and Scheduling tile shapes + using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; + using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; + using TileSchedulerTag = TileSchedulerTag_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr bool IsSchedDynamicPersistent = TileScheduler::IsDynamicPersistent; + + static constexpr bool IsDynamicCluster = not cute::is_static_v; + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + // Warp specialization thread count per threadblock + static constexpr uint32_t NumSchedThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMMAThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumMainloopLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueLoadThreads = NumThreadsPerWarp; // 1 warp + static constexpr uint32_t NumEpilogueThreads = CollectiveEpilogue::ThreadCount; + static constexpr uint32_t NumEpilogueWarps = NumEpilogueThreads / NumThreadsPerWarp; + + static constexpr uint32_t MaxThreadsPerBlock = NumSchedThreads + + NumMainloopLoadThreads + NumMMAThreads + + NumEpilogueLoadThreads + NumEpilogueThreads; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + static constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_load_pipe_increment(CtaShape_MNK{}); + + // Fixup performed for split-/stream-K is done across warps in different CTAs + // at epilogue subtile granularity. Thus, there must be one barrier per sub-tile per + // epilogue warp. 
+ static constexpr uint32_t NumFixupBarriers = 1; + static constexpr uint32_t CLCResponseSize = sizeof(typename TileScheduler::CLCResponse); + + // Pipeline and pipeline state types + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineState = typename CollectiveMainloop::MainloopPipelineState; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + using EpiLoadPipelineState = typename CollectiveEpilogue::LoadPipelineState; + + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + using EpiStorePipelineState = typename CollectiveEpilogue::StorePipelineState; + + using LoadOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + + using Mma2TransformPipeline = typename CollectiveMainloop::Mma2TransformPipeline; + using Mma2TransformPipelineState = typename Mma2TransformPipeline::PipelineState; + + using Load2TransformPipeline = typename CollectiveMainloop::Load2TransformPipeline; + using Load2TransformPipelineState = typename Load2TransformPipeline::PipelineState; + + using CLCPipeline = cutlass::PipelineCLCFetchAsync; + using CLCPipelineState = typename CLCPipeline::PipelineState; + + using CLCThrottlePipeline = cutlass::PipelineAsync; + using CLCThrottlePipelineState = typename CLCThrottlePipeline::PipelineState; + + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; + + static constexpr uint32_t GenericRegisterRequirement = 104; + static constexpr uint32_t AccumRegisterRequirement = 256; + + // Kernel level shared memory storage + struct SharedStorage { + // Barriers should be allocated in lower 8KB of SMEM for SM100 + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using Load2TransformPipelineStorage = typename CollectiveMainloop::Load2TransformPipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using LoadOrderBarrierStorage = typename LoadOrderBarrier::SharedStorage; + using CLCPipelineStorage = typename CLCPipeline::SharedStorage; + using Mma2TransformPipelineStorage = typename CollectiveMainloop::Mma2TransformPipelineStorage; + using CLCThrottlePipelineStorage = typename CLCThrottlePipeline::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) Load2TransformPipelineStorage load2transform; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) LoadOrderBarrierStorage load_order; + alignas(16) CLCPipelineStorage clc; + alignas(16) Mma2TransformPipelineStorage mma2transform; + alignas(16) CLCThrottlePipelineStorage clc_throttle; + alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier epilogue_throttle; + } pipelines; + + alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + + EpilogueTensorStorage epilogue; + MainloopTensorStorage mainloop; + } tensors; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "SMEM usage exceeded capacity."); + + // Host facing host arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + 
MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel device entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + TileSchedulerParams scheduler{}; + KernelHardwareInfo hw_info{}; + }; + + enum class WarpCategory : int32_t { + MMA = 0, + Sched = 1, + MainloopLoad = 2, + EpilogueLoad = 3, + Epilogue = 4 + }; + + struct IsParticipant { + uint32_t mma = false; + uint32_t sched = false; + uint32_t main_load = false; + uint32_t epi_load = false; + uint32_t epilogue = false; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count != 0) { + CUTLASS_TRACE_HOST(" WARNING: SM100 tile scheduler does not allow for user specified SM counts.\n" + " To restrict a kernel's resource usage, consider using CUDA driver APIs instead (green contexts)."); + } + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + // Tile scheduler + void* scheduler_workspace = workspace_ptr + workspace_offset; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace, args.hw_info), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, AtomThrShapeMNK{}, ClusterShape{}, + args.hw_info, args.scheduler, scheduler_workspace + ) + ,args.hw_info + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + if constexpr (IsDynamicCluster) { + static constexpr int MaxClusterSize = 16; + implementable &= size(args.hw_info.cluster_shape) <= MaxClusterSize; + implementable &= size(args.hw_info.cluster_shape_fallback) <= MaxClusterSize; + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, 
args.hw_info.cluster_shape_fallback); + } + + return implementable; + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + + // Epilogue + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + // Tile scheduler + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + // Epilogue + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + // Tile scheduler + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs, cuda_adapter); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumFixupBarriers, NumEpilogueSubTiles, CollectiveEpilogue::NumAccumulatorMtxs); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // NOTE cluster_shape here is the major cluster shape, not fallback one + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}, params.hw_info.cluster_shape); + + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_grid_shape( + params.scheduler, + problem_shape_MNKL, + TileShape{}, + AtomThrShapeMNK{}, + cluster_shape, + params.hw_info); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator() (Params const& params, char* smem_buf) { + + using namespace cute; + using X = Underscore; + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto [M,N,K,L] = problem_shape_MNKL; + + // Account for more than one epilogue warp + int warp_idx = canonical_warp_idx_sync(); + WarpCategory warp_category = warp_idx < static_cast(WarpCategory::Epilogue) ? 
WarpCategory(warp_idx) + : WarpCategory::Epilogue; + + uint32_t lane_predicate = cute::elect_one_sync(); + auto cluster_shape = cutlass::detail::select_cluster_shape(ClusterShape{}); + int cluster_size = size(cluster_shape); + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); + bool is_mma_leader_cta = cta_coord_v == 0; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + [[maybe_unused]] uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop(params.mainloop, cluster_shape, cta_rank_in_cluster); + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_category == WarpCategory::Sched) && lane_predicate) { + collective_mainloop.prefetch_tma_descriptors(); + } + if ((warp_category == WarpCategory::EpilogueLoad) && lane_predicate) { + collective_epilogue.prefetch_tma_descriptors(params.epilogue); + } + + // Do we load source tensor C or other aux inputs + bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); + IsParticipant is_participant = { + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MainloopLoad), // main_load + (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load + (warp_category == WarpCategory::Epilogue) // epilogue + }; + + // Mainloop Load pipeline + typename MainloopPipeline::Params mainloop_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (WarpCategory::MMA == warp_category) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + + typename Load2TransformPipeline::Params load2transform_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + load2transform_pipeline_params.role = Load2TransformPipeline::ThreadCategory::Consumer; + } + load2transform_pipeline_params.initializing_warp = 0; + load2transform_pipeline_params.producer_arv_count = CollectiveMainloop::NumLoad2TransformProducerThreadEvents; + load2transform_pipeline_params.consumer_arv_count = NumEpilogueThreads; + + Load2TransformPipeline load2transform_pipeline(shared_storage.pipelines.load2transform, + load2transform_pipeline_params); + + // Epilogue Load pipeline + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (WarpCategory::EpilogueLoad == warp_category) { + 
epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cta_rank_in_cluster; + epi_load_pipeline_params.producer_arv_count = NumEpilogueLoadThreads; + epi_load_pipeline_params.consumer_arv_count = NumEpilogueThreads; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + epi_load_pipeline_params.initializing_warp = 4; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Load order barrier + typename LoadOrderBarrier::Params load_order_barrier_params; + load_order_barrier_params.group_id = (warp_category == WarpCategory::MainloopLoad) ? 0 : 1; + load_order_barrier_params.group_size = NumMainloopLoadThreads; + load_order_barrier_params.initializing_warp = 5; + LoadOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, load_order_barrier_params); + + // CLC pipeline + typename CLCPipeline::Params clc_pipeline_params; + if (WarpCategory::Sched == warp_category) { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::ProducerConsumer; + } + else { + clc_pipeline_params.role = CLCPipeline::ThreadCategory::Consumer; + } + clc_pipeline_params.producer_blockid = 0; + clc_pipeline_params.producer_arv_count = 1; + clc_pipeline_params.consumer_arv_count = NumSchedThreads + cluster_size * + (NumMainloopLoadThreads + NumEpilogueThreads + NumMMAThreads); + if (is_epi_load_needed) { + clc_pipeline_params.consumer_arv_count += cluster_size * NumEpilogueLoadThreads; + } + clc_pipeline_params.transaction_bytes = CLCResponseSize; + clc_pipeline_params.initializing_warp = 1; + CLCPipeline clc_pipeline(shared_storage.pipelines.clc, clc_pipeline_params, cluster_shape); + + // Mainloop-Epilogue pipeline + typename Mma2TransformPipeline::Params mma2transform_pipeline_params; + if (WarpCategory::MMA == warp_category) { + mma2transform_pipeline_params.role = Mma2TransformPipeline::ThreadCategory::Producer; + } + if (WarpCategory::Epilogue == warp_category) { + mma2transform_pipeline_params.role = Mma2TransformPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. 
+ mma2transform_pipeline_params.producer_arv_count = 1; + mma2transform_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; + mma2transform_pipeline_params.initializing_warp = 2; + Mma2TransformPipeline mma2transform_pipeline(shared_storage.pipelines.mma2transform, + mma2transform_pipeline_params, + cluster_shape); + + // CLC throttle pipeline + typename CLCThrottlePipeline::Params clc_throttle_pipeline_params; + if (WarpCategory::MainloopLoad == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Producer; + } + if (WarpCategory::Sched == warp_category) { + clc_throttle_pipeline_params.role = CLCThrottlePipeline::ThreadCategory::Consumer; + } + clc_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads; + clc_throttle_pipeline_params.consumer_arv_count = NumSchedThreads; + clc_throttle_pipeline_params.dst_blockid = 0; + clc_throttle_pipeline_params.initializing_warp = 3; + CLCThrottlePipeline clc_throttle_pipeline(shared_storage.pipelines.clc_throttle, clc_throttle_pipeline_params); + CLCThrottlePipelineState clc_pipe_throttle_consumer_state; + CLCThrottlePipelineState clc_pipe_throttle_producer_state = cutlass::make_producer_start_state(); + + // Tmem allocator + TmemAllocator tmem_allocator{}; + + // Sync allocation status between MMA and epilogue warps within CTA + arch::NamedBarrier tmem_allocation_result_barrier(NumMMAThreads + NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + // Sync deallocation status between MMA warps of peer CTAs + arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; + [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + + if (WarpCategory::MMA == warp_category && has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + + + // Initialize smem barrier for prologue throttling. Epilogue warps are stalled until the prologue finishes. + arch::ClusterBarrier& epilogue_throttle_barrier = shared_storage.pipelines.epilogue_throttle; + if (WarpCategory::MMA == warp_category && lane_predicate) { + epilogue_throttle_barrier.init( NumMMAThreads + + (is_first_cta_in_cluster ? NumSchedThreads : 0) + + NumMainloopLoadThreads + + (is_epi_load_needed ? 
NumEpilogueLoadThreads : 0)); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer threadblocks in the cluster + pipeline_init_arrive_relaxed(cluster_size); + + auto load_inputs = collective_mainloop.load_init( + problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop); + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + EpiLoadPipelineState epi_load_pipe_consumer_state; + EpiLoadPipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + + // epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + EpiStorePipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CLCPipelineState clc_pipe_consumer_state; + CLCPipelineState clc_pipe_producer_state = cutlass::make_producer_start_state(); + + Mma2TransformPipelineState mma2transform_pipe_consumer_state; + Mma2TransformPipelineState mma2transform_pipe_producer_state = cutlass::make_producer_start_state(); + + Load2TransformPipelineState load2transform_pipe_consumer_state; + Load2TransformPipelineState load2transform_pipe_producer_state = cutlass::make_producer_start_state(); + + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + + // Calculate mask after cluster barrier arrival + mainloop_pipeline.init_masks(cluster_shape, block_id_in_cluster); + mma2transform_pipeline.init_masks(cluster_shape, block_id_in_cluster); + + // TileID scheduler + TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); + typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); + auto cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + // + // TMEM "Allocation" + // + auto tmem_storage = collective_mainloop.template init_tmem_tensors(EpilogueTile{}); + + pipeline_init_wait(cluster_size); + + if (is_participant.main_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_arrive = is_epi_load_needed; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + bool requires_clc_query = true; + + auto pipelines = cute::make_tuple(mainloop_pipeline, load2transform_pipeline); + auto states = cute::make_tuple(mainloop_pipe_producer_state, load2transform_pipe_producer_state); + + do { + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
+ auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}, load_inputs.k_tiles); + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + auto k_tile_prologue = min(MainloopPipeline::Stages, k_tile_count); + + if constexpr (IsSchedDynamicPersistent) { + if (is_first_cta_in_cluster && requires_clc_query) { + clc_throttle_pipeline.producer_acquire(clc_pipe_throttle_producer_state); + clc_throttle_pipeline.producer_commit(clc_pipe_throttle_producer_state); + ++clc_pipe_throttle_producer_state; + } + } + + // Start mainloop prologue loads, arrive on the epilogue residual load barrier, resume mainloop loads + auto [mainloop_producer_state_next, load2transform_producer_state_next, k_tile_iter_next] = collective_mainloop.load( + mainloop_pipeline, + load2transform_pipeline, + mainloop_pipe_producer_state, + load2transform_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter, k_tile_prologue + ); + mainloop_pipe_producer_state = mainloop_producer_state_next; + load2transform_pipe_producer_state = load2transform_producer_state_next; + + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + auto [mainloop_producer_state_next_, load2transform_producer_state_next_, unused_] = collective_mainloop.load( + mainloop_pipeline, + load2transform_pipeline, + mainloop_pipe_producer_state, + load2transform_pipe_producer_state, + load_inputs, + cta_coord_mnkl, + k_tile_iter_next, k_tile_count - k_tile_prologue + ); + mainloop_pipe_producer_state = mainloop_producer_state_next_; + load2transform_pipe_producer_state = load2transform_producer_state_next_; + + // Sync warp to prevent non-participating threads entering next wave early + __syncwarp(); + + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + } while (work_tile_info.is_valid()); + + collective_mainloop.load_tail( + mainloop_pipeline, + load2transform_pipeline, + mainloop_pipe_producer_state, + load2transform_pipe_producer_state + ); + + } + + else if (is_participant.sched) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + if constexpr (IsSchedDynamicPersistent) { + + // Whether a new CLC query must be performed. + // See comment below where this variable is updated for a description of + // why this variable is needed. + bool requires_clc_query = true; + + do { + if (requires_clc_query) { + // Throttle CLC query to mitigate workload imbalance caused by skews among persistent workers. + clc_throttle_pipeline.consumer_wait(clc_pipe_throttle_consumer_state); + clc_throttle_pipeline.consumer_release(clc_pipe_throttle_consumer_state); + ++clc_pipe_throttle_consumer_state; + + // Query next clcID and update producer state + clc_pipe_producer_state = scheduler.advance_to_next_work(clc_pipeline, clc_pipe_producer_state); + } + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + // Only perform a new CLC query if we consumed a new CLC query result in + // `fetch_next_work`. 
An example of a case in which CLC `fetch_next_work` does + // not consume a new CLC query response is when processing stream-K units. + // The current stream-K scheduler uses single WorkTileInfo to track multiple + // (potentially-partial) tiles to be computed via stream-K. In this case, + // `fetch_next_work` simply performs in-place updates on the existing WorkTileInfo, + // rather than consuming a CLC query response. + requires_clc_query = increment_pipe; + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + work_tile_info = next_work_tile_info; + } while (work_tile_info.is_valid()); + clc_pipeline.producer_tail(clc_pipe_producer_state); + + } + } + + else if (is_participant.mma) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Tmem allocation sequence + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + + auto mma_inputs = collective_mainloop.mma_init( + tmem_storage, + shared_storage.tensors.mainloop); + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (is_mma_leader_cta) { + auto [mainloop_pipe_consumer_state_, mma2transform_pipe_producer_state_] = collective_mainloop.mma( + cute::make_tuple(mainloop_pipeline, mma2transform_pipeline), + cute::make_tuple(mainloop_pipe_consumer_state, mma2transform_pipe_producer_state), + tmem_storage, + mma_inputs, + cta_coord_mnkl, + k_tile_count + ); + mainloop_pipe_consumer_state = mainloop_pipe_consumer_state_; + mma2transform_pipe_producer_state = mma2transform_pipe_producer_state_; + } + + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Hint on an early release of global memory resources. + // The timing of calling this function only influences performance, + // not functional correctness. 
+ cutlass::arch::launch_dependent_grids(); + + // Release the right to allocate before deallocations so that the next CTA can rasterize + tmem_allocator.release_allocation_lock(); + + // Leader MMA waits for leader + peer epilogues to release stage + if (is_mma_leader_cta) { + mma2transform_pipeline.producer_tail(mma2transform_pipe_producer_state); + } + // Signal to peer MMA that entire tmem allocation can be deallocated + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } + + // Free entire tmem allocation + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + else if (is_participant.epi_load) { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + bool do_load_order_wait = true; + bool do_tail_load = false; + int current_wave = 0; + + // Signal the epilogue warps to proceed once the prologue is complete + epilogue_throttle_barrier.arrive(); + + do { + bool compute_epilogue = TileScheduler::compute_epilogue(work_tile_info, params.scheduler); + + // Get current work tile and fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + work_tile_info = next_work_tile_info; + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + if (compute_epilogue) { + if (do_load_order_wait) { + load_order_barrier.wait(); + do_load_order_wait = false; + } + + bool reverse_epi_n = IsOverlappingAccum && (current_wave % 2 == 0); + epi_load_pipe_producer_state = collective_epilogue.template load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + shared_storage.tensors.epilogue, + reverse_epi_n + ); + + do_tail_load = true; + } + current_wave++; + + // Calculate the cta coordinates of the next work tile + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + } while (work_tile_info.is_valid()); + + // Only perform a tail load if one of the work units processed performed + // an epilogue load. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
+ if (do_tail_load) { + collective_epilogue.load_tail( + epi_load_pipeline, epi_load_pipe_producer_state, + epi_store_pipeline, epi_store_pipe_producer_state); + } + } + + else if (is_participant.epilogue) { + // Register reconfiguration + arch::warpgroup_reg_alloc(); + + // Throttle the epilogue warps to improve prologue performance + static constexpr int epilogue_throttle_phase_bit = 0; + epilogue_throttle_barrier.wait(epilogue_throttle_phase_bit); + + // Wait for tmem allocate here + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + + auto transform_inputs = collective_mainloop.transform_init( + problem_shape_MNKL, + shared_storage.tensors.mainloop + ); + + auto pipelines = cute::make_tuple(mma2transform_pipeline, load2transform_pipeline); + auto states = cute::make_tuple(mma2transform_pipe_consumer_state, load2transform_pipe_consumer_state); + bool do_tail_store = false; + do { + + auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, CtaShape_MNK{}); + + // Fetch next work tile + auto [next_work_tile_info, increment_pipe] = scheduler.fetch_next_work( + work_tile_info, + clc_pipeline, + clc_pipe_consumer_state + ); + + if (increment_pipe) { + ++clc_pipe_consumer_state; + } + + auto [accum, tiled_t2r, next_state] = collective_mainloop.transform( + pipelines, + states, + tmem_storage, + transform_inputs, + cta_coord_mnkl, + typename CollectiveEpilogue::CopyOpT2R{}, + typename CollectiveEpilogue::EpilogueTile{}, + k_tile_count + ); + + states = next_state; + + auto fixup_next_state = scheduler.template fixup( + TiledMma{}, + work_tile_info, + accum, + get<0>(pipelines), + get<0>(next_state), + typename CollectiveEpilogue::CopyOpT2R{} + ); + + get<0>(states) = fixup_next_state; + + // + // Epilogue and write to gD + // + if (scheduler.compute_epilogue(work_tile_info)) { + auto [load_state_next, store_state_next] = collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + CtaShape_MNK{}, + cta_coord_mnkl, + TileShape{}, + TiledMma{}, + accum, + shared_storage.tensors.epilogue, + tiled_t2r + ); + epi_load_pipe_consumer_state = load_state_next; + epi_store_pipe_producer_state = store_state_next; + do_tail_store = true; + } + work_tile_info = next_work_tile_info; + cta_coord_mnkl = scheduler.work_tile_to_cta_coord(work_tile_info); + + } while (work_tile_info.is_valid()); + + // Only perform a tail store if one of the work units processed performed + // an epilogue. An example of a case in which a tail load should not be + // performed is in split-K if a cluster is only assigned non-final splits (for which + // the cluster does not compute the epilogue). 
+ if (do_tail_store) { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + CtaShape_MNK{}); + } + } else { + // Register reconfiguration + arch::warpgroup_reg_dealloc(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 7c375747c..205233043 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -128,10 +128,11 @@ public: using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaThreads = CUTE_STATIC_V(size(TiledMma{})); + static constexpr uint32_t NumMmaThreads = size(TiledMma{}); static constexpr uint32_t NumMmaWarpGroups = NumMmaThreads / NumThreadsPerWarpGroup; static constexpr uint32_t MaxThreadsPerBlock = NumMmaThreads + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; /// Register requirement for Load and Math WGs static constexpr uint32_t LoadRegisterRequirement = 40; @@ -434,7 +435,8 @@ public: mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; } mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = size(TiledMma{}); + mainloop_pipeline_params.num_consumers = NumMmaThreads; + mainloop_pipeline_params.num_producers = NumProducerThreads; mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); @@ -575,6 +577,7 @@ public: auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); if (did_batch_change) { + load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch); collective_mainloop.tensormaps_fence_acquire(input_tensormaps); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index 5bd5196f7..3515f4d27 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -131,6 +131,7 @@ public: static constexpr uint32_t NumMmaWarpGroups = 2; static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t NumProducerThreads = CollectiveMainloop::NumProducerThreadEvents; /// Register requirement for Load and Math WGs static constexpr uint32_t LoadRegisterRequirement = 40; @@ -443,6 +444,7 @@ public: } mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.num_producers = NumProducerThreads; mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; MainloopPipeline 
mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); @@ -607,6 +609,7 @@ public: auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); if (did_batch_change) { + load_inputs = collective_mainloop.tensors_perform_update(load_inputs, params.mainloop, problem_shape_MNKL, curr_batch); collective_mainloop.tensormaps_fence_acquire(input_tensormaps); }
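As a closing illustration of how the new schedule parameters surface to user code, the compile-time sketch below instantiates one of the dispatch policies added in dispatch_policy.hpp and checks its static members. It is not part of the patch; the granularity and interval values are hypothetical, and it assumes only that the CUTLASS include path is set. The final comment restates the can_implement() requirement added in the blockwise-scaling collective above rather than introducing new behavior.

// Compile-time sketch (illustrative values): the new FP8 block-scaled schedule exposes its
// scaling granularities and promotion interval as static members.
#include <cutlass/gemm/dispatch_policy.hpp>

using Schedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<
    /*ScaleGranularityM_=*/1,      // one scale per row of the CTA tile (groupwise along M)
    /*ScaleGranularityN_=*/128,    // one scale per 128 columns
    /*ScalePromotionInterval_=*/8  // promote/scale the accumulator every 8 MMA instructions
>;

static_assert(Schedule::ScaleGranularityM == 1);
static_assert(Schedule::ScaleGranularityN == 128);
static_assert(Schedule::ScalePromotionInterval % 4 == 0,
              "Must be a multiple of 4: each mainloop iteration issues 4 MMA instructions");

// Note: the mainloop arguments' mma_promotion_interval must equal
// Schedule::ScalePromotionInterval, otherwise the blockwise-scaling collective's
// can_implement() rejects the kernel.

int main() { return 0; }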