diff --git a/CHANGELOG.md b/CHANGELOG.md index 3df2d095c..7a837307f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,40 @@ # CUTLASS 4.x +## [4.6.0](https://github.com/NVIDIA/cutlass/tree/main) (2026-06-11) + +### CuTe DSL +* New features + - Supported AoT cross-compilation for aarch64-linux-gnu + - Support for two launch attributes: launch completion events (cudaLaunchAttributeLaunchCompletionEvent), for recording an event once all thread blocks have begun executing, and launch programmatic events (cudaLaunchAttributeProgrammaticEvent), for PDL event-based synchronization + - Supported auto calculating per-kernel shared memory carveout preference, or use new launch option `preferred_smem_carveout` to set manually. + - Auto-deduced smem size for launching kernels + - Launch config `smem` now defaults to `None` for auto-calculating kernel shared memory usage, which is recommended unless manual control is required. + - Warnings will be raised when the manually set shared memory size is insufficient or exceeds the GPU maximum. + - The default shared memory usage calculation aligns with CUDA C++ static shared memory behavior, i.e. summing all allocations additively. + - An additional launch option `smem_merge_branch_allocs` is provided to merge shared memory allocations across mutually exclusive code branches, which is recommended for inlined mega-kernels to reduce total footprint. 
+ +* Bug fixing and improvements + - Improvements on linter support with more type ignores cleaned up + - Improvements on tvm-ffi CUDA runtime error diagnostics + - Improvements on dataclass support for TVM-FFI + - Fixed a regression on compilation time + - Enhancement on compile time checks to reject mis-aligned smem operand for TMA + - Long-deprecated API clean-up, including: + - cute.core.ThrMma, please use cute.ThrMma instead + - cute.core.ThrCopy, please use cute.ThrCopy instead + - cute.make_fragment, please use cute.make_rmem_tensor instead + +### CUTLASS C++ +* Add [example 113](https://github.com/NVIDIA/cutlass/tree/main/examples/113_hopper_gemm_activation_fusion) for Hopper GEMM with activation fusion. + - Supports standard and gated activations (e.g., SiLu) with fp8 and fp16 inputs. + - Covers both regular GEMM and grouped GEMM variants. +* Improve SM90 grouped/ptr-array GEMM with EVT support. + - Adds the EVT (Epilogue Visitor Tree) plumbing required to do activation, bias, and auxiliary-tensor fusion inside SM90 grouped and ptr-array GEMM kernels. +* Fix `DescriptorIterator::operator+` in `mma_traits_sm100.hpp` to use 32-bit arithmetic on CUDA toolkit version <= 13.3, preserving the high half of the smem descriptor. +* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! +* Optimal code generation with CUDA toolkit versions 13.3. 
+ ## [4.5.2](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.2) (2026-05-22) ### CuTe DSL diff --git a/README.md b/README.md index fdbb6f3db..75e590e7f 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ ![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") # Overview -# CUTLASS 4.5.2 +# CUTLASS 4.6.0 -_CUTLASS 4.5.2 - May 2026_ +_CUTLASS 4.6.0 - June 2026_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -37,86 +37,43 @@ We believe it will become an indispensable tool for students, researchers, and p engineers alike — flattening the learning curve of GPU programming, rapidly prototyping kernel designs, and bringing optimized solutions into production. -CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2025. +CuTe DSL is currently in public beta and will graduate out of beta by end of summer 2026. To get started quickly - please refer : - [CUTLASS C++ Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/cpp/quickstart.html). - [CuTe DSL Quick Start Guide](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/quick_start.html). -# What's New in CUTLASS 4.5 +# What's New in CUTLASS 4.6 ## CuTe DSL * New features - - New Block API `block_copy()` to simplify TMA and S2T copy. Users can ignore detail about multicast and 2CTA partition for TMA by `block_copy()` and need not to invoke `tma_partition()`. And users can remove bulk of S2T initialization to simplify S2T copy. 
- - MXF8F6F4 mixed precision support - - BlockScaled MMA now supports MXF8*MXF4 or MXF8*MXF6 - - Block Scaled MMA for SM120 now works on Spark - - EFC broadcast semantics support - - EFC epilogue functions can now broadcast and remap tensor modes via `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). Covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose). The PyTorch reference evaluator mirrors the same transformations. - - Initial linter support: Improved type hints on CuTe DSL APIs to support static type checkers like MyPy - - dataclasses.dataclass is now supported for JIT compilaton and cute.compile for both plain and tvm-ffi path - - cute.copy now supports user specified loop unrolling - - Python 3.14t is now supported with GIL enabled + - Supported AoT cross-compilation for aarch64-linux-gnu + - Support for two launch attributes: launch completion events (cudaLaunchAttributeLaunchCompletionEvent), for recording an event once all thread blocks have begun executing, and launch programmatic events (cudaLaunchAttributeProgrammaticEvent), for PDL event-based synchronization + - Supported auto calculating per-kernel shared memory carveout preference, or use new launch option `preferred_smem_carveout` to set manually. + - Auto-deduced smem size for launching kernels + - Launch config `smem` now defaults to `None` for auto-calculating kernel shared memory usage, which is recommended unless manual control is required. + - Warnings will be raised when the manually set shared memory size is insufficient or exceeds the GPU maximum. + - The default shared memory usage calculation aligns with CUDA C++ static shared memory behavior, i.e. summing all allocations additively. 
+ - An additional launch option `smem_merge_branch_allocs` is provided to merge shared memory allocations across mutually exclusive code branches, which is recommended for inlined mega-kernels to reduce total footprint. * Bug fixing and improvements - - Improved source code correlation for profiling/debugging - - Fixed an aarch64 segfault issue with tvm-ffi - - Re-organization for CuTe DSL examples/tutorials for better discoverability - - Fixed following issues: - https://github.com/NVIDIA/cutlass/issues/3219 - https://github.com/NVIDIA/cutlass/issues/3218 - https://github.com/NVIDIA/cutlass/issues/3212 - https://github.com/NVIDIA/cutlass/issues/3210 - https://github.com/NVIDIA/cutlass/issues/3208 - https://github.com/NVIDIA/cutlass/issues/3201 - https://github.com/NVIDIA/cutlass/issues/3227 - https://github.com/NVIDIA/cutlass/issues/3240 - https://github.com/NVIDIA/cutlass/issues/3241 - - Fixed Jax int64 stride divisibility issue - - Fixed issues for SM120 blockscaled MMAs - - added missing MXFP8MMAOP and MXF8F6F4MMAOP for sm120. - -* More examples of authorizing peak-performance kernels - - MOE examles - - A new style of grouped-gemm that aligns to torch's grouped_mm and scaled_groued_mm interface. - - Expert-wise tensormap descriptor setup by a cheap helper kernel (~2us) to avoid long latency in tile switching, kernel structure is much more closer to a normal GEMM. - - Compared to torch_210_cu13, very few problem has worse perf in B200. - - mxfp8_2dx3d: avg 1.29 speedup; - - mxfp8_2dx2d: avg 1.41 speedup; - - nvfp4_2dx3d: avg 1.11 speedup; - - nvfp4_2dx2d: avg 1.12 speedup (worst case 0.98) - - bf16_2dx3d: avg 1.15 speedup (worst case 0.98) - - bf16_2dx2d: avg 1.17 speedup (worst case 0.96) - - Note: The perf is measured from torch profiler, this impl includes the helper kernel + main kernel, while torch's includes its setup kernel and cutlass_cpp main kernel. 
- -* API changes - - ab_dtype is deprecated in make_trivial_tiled_mma and make_blockscaled_trivial_tiled_mma from blackwell_helpers.py. Please specify a_dtype and b_dtype separately instead. + - Improvements on linter support with more type ignores cleaned up + - Improvements on tvm-ffi CUDA runtime error diagnostics + - Improvements on dataclass support for TVM-FFI + - Fixed a regression on compilation time + - Enhancement on compile time checks to reject mis-aligned smem operand for TMA + - Long-deprecated API clean-up, including: + - cute.core.ThrMma, please use cute.ThrMma instead + - cute.core.ThrCopy, please use cute.ThrCopy instead + - cute.make_fragment, please use cute.make_rmem_tensor instead ## CUTLASS C++ -* Add 2SM MMA instruction support to mixed TMA+CpAsync SM100 vanilla GEMM kernels. - - Mixed TMA+CpAsync can now accept static, but non trivial cluster shapes. - - Uses TMA multicast for A tile when using non-trivial cluster size along N mode. - - Uses an additional barrier (mma_trampoline_barrier) to track cp.async arrivals in both CTAs. - - Changes included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm). -* Add support for 128x32xK and 128x64xK tile sizes for SM120 blockscaled MMA collective builders, yielding up to 30% performance improvement on Blackwell SM121 related kernels. -* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). -* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation. -* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green context SM partition - - Enables launching GEMM on stream with partial SM allocation. -* Add [Snake](https://github.com/NVIDIA/cutlass/blob/main/test/unit/epilogue/thread/activation.cu#L409) activation functor for EVT. 
-* Fix SM100 F8F6F4 SS MMA (1SM and 2SM) traits to use typed op templates. -* Add UE8M0 (uniform exponent distribution) initialization support in tensor fill utilities. -* Add `cvt.rn.bf16x2.e4m3x2` conversion instruction support to `numeric_conversion.h`. -* Update [example 93](https://github.com/NVIDIA/cutlass/tree/main/examples/93_blackwell_low_latency_gqa) with paged KV cache support for Blackwell low-latency GQA. -* Fix some kernel issues: - - Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates - - Fix CUTLASS clang build issues - - Remove `PipelineStorage` shadowing in SM100 complex epilogue - - Fix build issue in SM90 epilogue fusion visitor TMA warpspecialized - - Fix missing convert fucntion in EVT for fp4 kernels -* Fix some profiler issues: - - Add missing reference kernels for blockwise GEMM profiler. - - Avoid instantiate 2sm tma kernels where ctaN is none power of 64 when ctaN > 128 in profiler. +* Add [example 113](https://github.com/NVIDIA/cutlass/tree/main/examples/113_hopper_gemm_activation_fusion) for Hopper GEMM with activation fusion. + - Supports standard and gated activations (e.g., SiLu) with fp8 and fp16 inputs. + - Covers both regular GEMM and grouped GEMM variants. +* Improve SM90 grouped/ptr-array GEMM with EVT support. + - Adds the EVT (Epilogue Visitor Tree) plumbing required to do activation, bias, and auxiliary-tensor fusion inside SM90 grouped and ptr-array GEMM kernels. +* Fix `DescriptorIterator::operator+` in `mma_traits_sm100.hpp` to use 32-bit arithmetic on CUDA toolkit version <= 13.3, preserving the high half of the smem descriptor. Note: CUTLASS 4.x builds are known to be down on Windows platforms for all CUDA toolkits. CUTLASS team is working on a fix. 
diff --git a/examples/113_hopper_gemm_activation_fusion/113_hopper_gemm_fused_act.cu b/examples/113_hopper_gemm_activation_fusion/113_hopper_gemm_fused_act.cu new file mode 100644 index 000000000..d8b7c7413 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/113_hopper_gemm_fused_act.cu @@ -0,0 +1,533 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM with activation fusion example +*/ + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" +#include "options.hpp" +#include "utils.hpp" +#include "sm90_lin_comb_elt_act_scaled.hpp" +#include "activation_kernel.cuh" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if 0 
+template +using ActivationFn = cutlass::epilogue::thread::ReLu; +#elif 1 +template +using ActivationFn = cutlass::epilogue::thread::SiLu; +#else +template +using ActivationFn = cutlass::epilogue::thread::Identity; +#endif + +bool constexpr IsFp8 = true; // whether to run with fp8 or fp16 input/output +bool constexpr Quantize = true; // whether to quantize output with a per-tensor scale factor +bool constexpr ExactMode = false; // whether to reproduce unfused dual gemm+activation exactly +bool constexpr BiasBroadcast = true; // whether bias is broadcast along columns in each group +bool constexpr Pingpong = true; // whether to use pingpong schedule + +// A matrix configuration +using ElementA = conditional_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = conditional_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operand +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +using ElementD = conditional_t; // Element type for D matrix operand +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) + 
+// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ElementScalar = float; // Element type for epilogue scalars (e.g. alpha and beta) +using ElementIntermediate = cutlass::half_t; // Element type of intermediate result between GEMM and bias+activation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using EpiTileShape = cutlass::epilogue::collective::EpilogueTileAuto; // Epilogue sub-tile shape +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using TileShapeK = Int<128 * 8 / sizeof_bits::value>; + +using KernelScheduleCooperative = conditional_t(), + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>; + +using KernelSchedulePingpong = conditional_t(), + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecializedPingpong>; + +using KernelSchedule = conditional_t; +using EpilogueSchedule = conditional_t; +using TileShape = conditional_t, Shape<_128,_256,TileShapeK>>; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpiTileShape, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::AccCastLinCombEltActScale< + Quantize, + ActivationFn, + ElementD, + ElementCompute, + ElementC, + ElementScalar, + conditional_t + > + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + 
cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpiTileShape, + ElementAccumulator, ElementCompute, + void, LayoutC, AlignmentC, + ElementIntermediate, LayoutD, AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::ScaledAcc + >::CollectiveOp; + +using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef +>; + +using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename GemmKernel::StrideA; +using StrideB = typename GemmKernel::StrideB; +using StrideC = typename GemmKernel::StrideC; +using StrideD = typename GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_D_ref; +cutlass::DeviceAllocation block_D_ref_gemm; +cutlass::DeviceAllocation offset_col_D; 
+cutlass::DeviceAllocation block_scale; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options : GemmOptionsBase { + + using Base = GemmOptionsBase; + + float alpha = 1.f, beta = 0.f; + int m = 10240, n = 2048, k = 8192, l = 10; + + // Parses the command line + void parse(cutlass::CommandLine const& cmd) { + Base::parse(cmd); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << program_path << "\n" + "\n" + " Hopper GEMM with fused activation function.\n" + "\n" + "Options:\n" + "\n" + " --help If specified, displays this usage statement\n\n" + " --m= Sets the M extent of the GEMM\n" + " --n= Sets the N extent of the GEMM\n" + " --k= Sets the K extent of the GEMM\n" + " --l= Sets the L extent of the GEMM\n" + " --alpha= Epilogue scalar alpha\n" + " --beta= Epilogue scalar beta\n" + " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n" + " --swizzle= CTA Rasterization swizzle\n" + " --warmup= Number of warmup iterations to perform.\n" + " --iterations= Number of profiling iterations to perform.\n" + " --verbose Verbose mode (output detailed verification result)\n" + " --verify= Verification (correctness check) on/off\n" + " --sms Number of SMs to run the GEMMs on\n" + " --device Device index\n" + "\n" + "Examples:\n" + "\n" + << program_path << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707\n"; + + return out; + } + + /// Compute total number of floating 
point operations + double total_flops() const { + // Two flops per multiply-add + return uint64_t(2) * m * n * k * l; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + auto [M, N, K, L] = make_tuple(options.m, options.n, options.k, options.l); + auto NC = BiasBroadcast ? 1 : N; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, L}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, L}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {M, NC, L}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, L}); + + block_A.reset(M * K * L); + block_B.reset(N * K * L); + block_C.reset(M * NC * L); + block_D.reset(M * N * L); + block_D_ref.reset(M * N * L); + block_D_ref_gemm.reset(M * N * L); + block_scale.reset(1); + + if constexpr (BiasBroadcast) { + get<1>(stride_C) = 0; + } + + std::vector offset_col_D_host(options.l + 1); + std::iota(offset_col_D_host.begin(), offset_col_D_host.end(), 0ll); + std::transform(offset_col_D_host.begin(), offset_col_D_host.end(), offset_col_D_host.begin(), [&](auto i) { return i * options.n; }); + offset_col_D.reset(options.l + 1); + offset_col_D.copy_from_host(offset_col_D_host.data()); + + cutlass::reference::device::BlockFillRandom(block_A.get(), block_A.size(), 2024, options.dist_a); + cutlass::reference::device::BlockFillRandom(block_B.get(), block_B.size(), 2025, options.dist_b); + cutlass::reference::device::BlockFillRandom(block_C.get(), block_C.size(), 2026, options.dist_c); + cutlass::reference::device::BlockFillRandomUniform(block_scale.get(), block_scale.size(), 2027, 0.5, 1.0); +} + +template +auto args_from_options_common(const 
Options &options) +{ + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = options.sm_count > 0 + ? options.sm_count + : cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args{}; + + fusion_args.alpha = options.alpha; + + return make_tuple(fusion_args, hw_info); +} + +template +typename GemmT::Arguments +args_from_options(const Options &options); + +template <> +typename Gemm::Arguments +args_from_options(Options const& options) +{ + auto [fusion_args, hw_info] = args_from_options_common(options); + fusion_args.beta = options.beta; + fusion_args.scale_ptr = Quantize ? block_scale.get() : nullptr; + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {fusion_args, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info, + {options.swizzle, options.raster} + }; + + return arguments; +} + +template <> +typename GemmRef::Arguments +args_from_options(Options const& options) +{ + auto [fusion_args, hw_info] = args_from_options_common(options); + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {fusion_args, nullptr, stride_C, block_D_ref_gemm.get(), stride_D}, + hw_info, + {options.swizzle, options.raster} + }; + + return arguments; +} + +bool verify(const Options &options) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = false; + if constexpr (ExactMode) { + passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size()); + } + else { + passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_D_ref.get(), block_D.get(), block_D.size(), 
ElementD(options.tolerance), ElementD(options.nonzero_floor)); + } + + if (!passed && options.verbose) { + print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(options.m, options.n, options.l), stride_D)); + print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(options.m, options.n, options.l), stride_D)); + } + + return passed; +} + +/// Execute a given example GEMM computation +bool run(Options& options) +{ + if (options.beta != 1.f && options.beta != 0.f) { + throw std::runtime_error("Specifying beta != 0/1 is not supported by verification kernel"); + } + + initialize(options); + + std::cout << "Problem Size: " << shape_string(make_tuple(options.m, options.n, options.k, options.l)) << std::endl; + std::cout << "Data types: " << problem_desc_string() << std::endl; + std::cout << "Activation function: " << activation_func_string() << std::endl; + std::cout << "Kernel schedule: " << kernel_schedule_string() << std::endl; + std::cout << "GEMM tile shape: " << shape_string(TileShape{}) << std::endl; + std::cout << "Epi tile shape: " << epilogue_tile_string(EpiTileShape{}) << std::endl; + std::cout << "Cluster shape: " << shape_string(ClusterShape{}) << std::endl; + std::cout << "Rasterization: " << options.raster_string() << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << "Options: Quantize = " << Quantize << ", Exact = " << ExactMode << ", BiasBroadcast = " << BiasBroadcast << std::endl; + + Runner gemm(args_from_options(options)); + Runner gemm_ref(args_from_options(options)); + + auto run_fused = [&](){ gemm.run(); }; + auto run_ref_gemm = [&](){ gemm_ref.run(); }; + auto run_activation = [&](){ + do_activation( + block_D_ref.get(), + block_D_ref_gemm.get(), + Quantize ? block_scale.get() : static_cast(nullptr), + options.beta != 0.f ? 
block_C.get() : static_cast(nullptr), + BiasBroadcast, + offset_col_D.get(), + options.l, + options.m, + options.n * options.l, + false); + }; + auto run_unfused = [&](){ run_ref_gemm(); run_activation(); }; + + run_fused(); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Correctness check + bool passed = true; + if (options.verify) { + run_unfused(); + CUDA_CHECK(cudaDeviceSynchronize()); + + passed = verify(options); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + } + + // Run profiling loop + if (options.iterations > 0) + { + auto benchmark = [&](auto name, auto func) + { + BenchmarkResult result = run_benchmark(func, options.warmup, options.iterations); + double avg_tflops = double(options.total_flops()) / result.avg_runtime_ms / 1e9; // FLOP/ms -> TFLOP/s + printf(options.csv ? "%s,%.3f,%.0f\n" : "%20s %20.3f %20.0f\n", + name, result.avg_runtime_ms, avg_tflops); + }; + printf(options.csv ? "%s,%s,%s\n" : "%20s %20s %20s\n", + "Kernel", "Runtime (ms)", "Throughput (Tflop/s)"); + benchmark("Fused", run_fused); + benchmark("Unfused", run_unfused); + benchmark("GEMM only", run_ref_gemm); + } + + return passed; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + try { + Options options; + cutlass::CommandLine cmd(argc, args); + options.parse(cmd); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return EXIT_SUCCESS; + } + + if (options.device >= 0) { + CUDA_CHECK(cudaSetDevice(options.device)); + } + else { + CUDA_CHECK(cudaGetDevice(&options.device)); + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, options.device)); + if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return EXIT_SUCCESS; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (!run(options)) { + return EXIT_FAILURE; + } +#endif + } + catch (std::exception const& e) { + std::cerr << e.what() << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/113_hopper_gemm_activation_fusion/113_hopper_gemm_fused_gated_act.cu b/examples/113_hopper_gemm_activation_fusion/113_hopper_gemm_fused_gated_act.cu new file mode 100644 index 000000000..f4046c2d0 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/113_hopper_gemm_fused_gated_act.cu @@ -0,0 +1,559 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Hopper GEMM with activation fusion example +*/ + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" +#include "options.hpp" +#include "utils.hpp" +#include "gated_stride.hpp" +#include "gated_builder.hpp" +#include "activation_kernel.cuh" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if 0 +template +using ActivationFn = cutlass::epilogue::thread::ReLu; +#elif 1 +template +using ActivationFn = cutlass::epilogue::thread::SiLu; +#else +template +using ActivationFn = cutlass::epilogue::thread::Identity; +#endif + +bool constexpr IsFp8 = true; // whether to run with fp8 or fp16 input/output +bool constexpr Quantize = true; // whether to quantize output with a per-tensor scale factor +bool constexpr ExactMode = false; // whether to reproduce unfused dual gemm+activation exactly +bool constexpr BiasBroadcast = true; // whether bias is broadcast along columns in each group +bool constexpr Pingpong = true; // whether to use 
pingpong schedule + +using ProblemShape = Shape; +using GatedProblemShape = decltype(cutlass::sm90_make_gated_shape<0>(ProblemShape{})); + +// A matrix configuration +using ElementA = conditional_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = conditional_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operand +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +using ElementD = conditional_t; // Element type for D matrix operand +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue compute +using ElementScalar = float; // Element type for scalar values (alpha, beta) +using ElementIntermediate = cutlass::half_t; // Element type of intermediate result between GEMM and bias+activation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = 
cutlass::arch::OpClassTensorOp; // Operator class tag +using EpiTileShape = cutlass::epilogue::collective::EpilogueTileAuto; // Epilogue sub-tile shape +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using TileShapeK = Int<128 * 8 / sizeof_bits::value>; + +using KernelScheduleCooperative = conditional_t(), + cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>; + +using KernelSchedulePingpong = conditional_t(), + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, + cutlass::gemm::KernelTmaWarpSpecializedPingpong>; + +using KernelSchedule = conditional_t; +using EpilogueSchedule = conditional_t; +using TileShape = conditional_t, Shape<_128,_256,TileShapeK>>; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::Sm90CollectiveBuilderGated< + OperatorClass, + TileShape, ClusterShape, + EpiTileShape, + ElementAccumulator, ElementCompute, ElementScalar, + conditional_t, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + ActivationFn, + Quantize +>::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::Sm90CollectiveBuilderGated< + OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule +>::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + GatedProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpiTileShape, + ElementAccumulator, ElementCompute, + void, LayoutC, 
AlignmentC, + ElementIntermediate, LayoutD, AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::ScaledAcc + >::CollectiveOp; + +using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopRef, + CollectiveEpilogueRef +>; + +using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = GemmKernel::StrideA; +using StrideB = GemmKernel::StrideB; +using StrideC = GemmKernel::StrideC; +using StrideD = GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::GmemLayoutTagAux; + +using StrideARef = GemmKernelRef::StrideA; +using StrideBRef = GemmKernelRef::StrideB; +using StrideCRef = GemmKernelRef::StrideC; +using StrideDRef = GemmKernelRef::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +StrideARef stride_A_ref; +StrideBRef stride_B_ref; +StrideCRef stride_C_ref; +StrideDRef stride_D_ref; +StrideDRef stride_D_ref_gemm; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_D_ref; +cutlass::DeviceAllocation block_D_ref_gemm; +cutlass::DeviceAllocation offset_col_D; +cutlass::DeviceAllocation block_scale; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing 
+struct Options : GemmOptionsBase { + + using Base = GemmOptionsBase; + + float alpha = 1.f, beta = 0.f; + int m = 10240, n = 2048, k = 8192, l = 10; + + // Parses the command line + void parse(cutlass::CommandLine const& cmd) { + Base::parse(cmd); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << program_path << "\n" + "\n" + " Hopper GEMM with fused activation function.\n" + "\n" + "Options:\n" + "\n" + " --help If specified, displays this usage statement\n\n" + " --m= Sets the M extent of the GEMM\n" + " --n= Sets the N extent of the GEMM\n" + " --k= Sets the K extent of the GEMM\n" + " --l= Sets the L extent of the GEMM\n" + " --alpha= Epilogue scalar alpha\n" + " --beta= Epilogue scalar beta (must be 0 or 1)\n" + " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n" + " --swizzle= CTA Rasterization swizzle\n" + " --warmup= Number of warmup iterations to perform.\n" + " --iterations= Number of profiling iterations to perform.\n" + " --verbose Verbose mode (output detailed verification result)\n" + " --verify= Verification (correctness check) on/off\n" + " --sms Number of SMs to run the GEMMs on\n" + " --device Device index\n" + "\n" + "Examples:\n" + "\n" + << program_path << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=1\n"; + + return out; + } + + /// Compute total number of floating point operations + double total_flops() const { + // Two flops per multiply-add + return uint64_t(2) * m * n * k * l; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + auto [M, N, K, L] = make_tuple(options.m, options.n, options.k, options.l); + auto NC = BiasBroadcast ? 1 : N; + + namespace cd = cutlass::gemm::collective::detail; + + stride_A = cutlass::sm90_make_gated_packed_stride(StrideA{}, {M, K, L}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, L}); + stride_C = cutlass::sm90_make_gated_packed_stride(StrideC{}, {M, NC, L}); + stride_D = cutlass::sm90_make_gated_packed_stride(StrideD{}, {M/2, N, L}); + + stride_A_ref = cutlass::make_cute_packed_stride(StrideARef{}, {M, K, L}); + stride_B_ref = cutlass::make_cute_packed_stride(StrideBRef{}, {N, K, L}); + stride_C_ref = cutlass::make_cute_packed_stride(StrideCRef{}, {M, NC, L}); + stride_D_ref = cutlass::make_cute_packed_stride(StrideDRef{}, {M/2, N, L}); + stride_D_ref_gemm = cutlass::make_cute_packed_stride(StrideDRef{}, {M, N, L}); + + block_A.reset(M * K * L); + block_B.reset(N * K * L); + block_C.reset(M * NC * L); + block_D.reset(M/2 * N * L); + block_D_ref.reset(M/2 * N * L); + block_D_ref_gemm.reset(M * N * L); + block_scale.reset(1); + + if constexpr (BiasBroadcast) { + get<1>(stride_C) = 0; + get<1>(stride_C_ref) = 0; + } + + std::vector offset_col_D_host(options.l + 1); + std::iota(offset_col_D_host.begin(), offset_col_D_host.end(), 0ll); + std::transform(offset_col_D_host.begin(), offset_col_D_host.end(), offset_col_D_host.begin(), [&](auto i) { return i * options.n; }); + offset_col_D.reset(options.l + 1); + offset_col_D.copy_from_host(offset_col_D_host.data()); + + cutlass::reference::device::BlockFillRandom(block_A.get(), block_A.size(), 2024, options.dist_a); + cutlass::reference::device::BlockFillRandom(block_B.get(), block_B.size(), 2025, options.dist_b); + cutlass::reference::device::BlockFillRandom(block_C.get(), 
block_C.size(), 2026, options.dist_c); + cutlass::reference::device::BlockFillRandomUniform(block_scale.get(), block_scale.size(), 2027, 0.5, 1.0); +} + +template +auto args_from_options_common(const Options &options) +{ + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = options.sm_count > 0 + ? options.sm_count + : cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args{}; + + fusion_args.alpha = options.alpha; + + return make_tuple(fusion_args, hw_info); +} + +template +typename GemmT::Arguments +args_from_options(const Options &options); + +template <> +typename Gemm::Arguments +args_from_options(Options const& options) +{ + auto [fusion_args, hw_info] = args_from_options_common(options); + + fusion_args.beta = options.beta; + fusion_args.scale_ptr = Quantize ? block_scale.get() : nullptr; + fusion_args.ptr_D = block_D.get(); + fusion_args.dD = stride_D; + + namespace cd = cutlass::gemm::collective::detail; + auto problem_shape = cutlass::sm90_make_gated_shape<0>(make_shape(options.m, options.n, options.k, options.l)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {fusion_args, block_C.get(), stride_C, nullptr, {}}, + hw_info, + {options.swizzle, options.raster} + }; + + return arguments; +} + +template <> +typename GemmRef::Arguments +args_from_options(Options const& options) +{ + auto [fusion_args, hw_info] = args_from_options_common(options); + + auto problem_shape = make_shape(options.m, options.n, options.k, options.l); + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + {block_A.get(), stride_A_ref, block_B.get(), stride_B_ref}, + {{options.alpha, 0}, nullptr, stride_C_ref, block_D_ref_gemm.get(), stride_D_ref_gemm}, + hw_info, + {options.swizzle, 
options.raster} + }; + + return arguments; +} + +bool verify(const Options &options) { + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = false; + if constexpr (ExactMode) { + passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size()); + } + else { + passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_D_ref.get(), block_D.get(), block_D.size(), ElementD(options.tolerance), ElementD(options.nonzero_floor)); + } + + if (!passed && options.verbose) { + auto [M,N,K,L] = make_shape(options.m, options.n, options.k, options.l); + print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(M/2,N,L), stride_D_ref)); + print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(M/2,N,L), stride_D_ref)); + } + + return passed; +} + +/// Execute a given example GEMM computation +bool run(Options& options) +{ + if (options.beta != 1.f && options.beta != 0.f) { + throw std::runtime_error("Specifying beta != 0/1 is not supported by verification kernel"); + } + + initialize(options); + + std::cout << "Problem Size: " << shape_string(make_tuple(options.m, options.n, options.k, options.l)) << std::endl; + std::cout << "Data types: " << problem_desc_string() << std::endl; + std::cout << "Activation function: " << activation_func_string() << std::endl; + std::cout << "Kernel schedule: " << kernel_schedule_string() << std::endl; + std::cout << "GEMM tile shape: " << shape_string(TileShape{}) << std::endl; + std::cout << "Epi tile shape: " << epilogue_tile_string(EpiTileShape{}) << std::endl; + std::cout << "Cluster shape: " << shape_string(ClusterShape{}) << std::endl; + std::cout << "Rasterization: " << options.raster_string() << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << "Options: Quantize = " << Quantize << ", Exact = " << ExactMode << ", BiasBroadcast = " << BiasBroadcast << std::endl; + + 
Runner gemm(args_from_options(options)); + Runner gemm_ref(args_from_options(options)); + + auto run_fused = [&](){ gemm.run(); }; + auto run_ref_gemm = [&](){ gemm_ref.run(); }; + auto run_activation = [&](){ + do_activation( + block_D_ref.get(), + block_D_ref_gemm.get(), + Quantize ? block_scale.get() : static_cast(nullptr), + options.beta != 0.f ? block_C.get() : static_cast(nullptr), + BiasBroadcast, + offset_col_D.get(), + options.l, + options.m / 2, + options.n * options.l, + true); + }; + auto run_unfused = [&](){ run_ref_gemm(); run_activation(); }; + + run_fused(); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Correctness check + bool passed = true; + if (options.verify) { + run_unfused(); + CUDA_CHECK(cudaDeviceSynchronize()); + + passed = verify(options); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + } + + // Run profiling loop + if (options.iterations > 0) + { + auto benchmark = [&](auto name, auto func) + { + BenchmarkResult result = run_benchmark(func, options.warmup, options.iterations); + double avg_tflops = double(options.total_flops()) / result.avg_runtime_ms / 1e9; // FLOP/ms -> TFLOP/s + printf(options.csv ? "%s,%.3f,%.0f\n" : "%20s %20.3f %20.0f\n", + name, result.avg_runtime_ms, avg_tflops); + }; + printf(options.csv ? "%s,%s,%s\n" : "%20s %20s %20s\n", + "Kernel", "Runtime (ms)", "Throughput (Tflop/s)"); + benchmark("Fused", run_fused); + benchmark("Unfused", run_unfused); + benchmark("GEMM only", run_ref_gemm); + } + + return passed; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. 
Its actions are no-op. + return 0; + } + + try { + Options options; + cutlass::CommandLine cmd(argc, args); + options.parse(cmd); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return EXIT_SUCCESS; + } + + if (options.device >= 0) { + CUDA_CHECK(cudaSetDevice(options.device)); + } + else { + CUDA_CHECK(cudaGetDevice(&options.device)); + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, options.device)); + if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return EXIT_SUCCESS; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (!run(options)) { + return EXIT_FAILURE; + } +#endif + } + catch (std::exception const& e) { + std::cerr << e.what() << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/113_hopper_gemm_activation_fusion/113_hopper_grouped_gemm_fused_act.cu b/examples/113_hopper_gemm_activation_fusion/113_hopper_grouped_gemm_fused_act.cu new file mode 100644 index 000000000..81095fd7b --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/113_hopper_grouped_gemm_fused_act.cu @@ -0,0 +1,655 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Hopper grouped GEMM with activation fusion example +*/ + +#include +#include +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" +#include "options.hpp" +#include "utils.hpp" +#include "sm90_lin_comb_elt_act_scaled.hpp" +#include "activation_kernel.cuh" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if 0 +template +using ActivationFn = cutlass::epilogue::thread::ReLu; +#elif 1 +template +using ActivationFn = cutlass::epilogue::thread::SiLu; +#else +template +using ActivationFn = cutlass::epilogue::thread::Identity; +#endif + +bool constexpr IsFp8 = true; // whether to run with fp8 or fp16 input/output +bool constexpr Quantize = true; // whether to quantize output with a per-tensor scale factor +bool constexpr ExactMode = false; // whether to reproduce unfused dual gemm+activation exactly +bool constexpr BiasBroadcast = true; // whether bias is broadcast along columns in each group +bool constexpr Pingpong = true; // whether to use pingpong schedule + +using ProblemShape 
= Shape; // per group +using GroupProblemShape = cutlass::gemm::GroupProblemShape; + +// A matrix configuration +using ElementA = conditional_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = conditional_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operand +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = conditional_t; // Element type for D matrix operand +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Alignment of D matrix in units of elements (up to 16 bytes) + +int constexpr AlignmentM = max(make_tuple(cutlass::gemm::detail::is_mn_major_A() ? AlignmentA : 1, + cutlass::gemm::detail::is_m_major_C() ? AlignmentC : 1, + cutlass::gemm::detail::is_m_major_C() ? AlignmentD : 1)); +int constexpr AlignmentN = max(make_tuple(cutlass::gemm::detail::is_mn_major_B() ? AlignmentB : 1, + cutlass::gemm::detail::is_n_major_C() ? AlignmentC : 1, + cutlass::gemm::detail::is_n_major_C() ? AlignmentD : 1)); +int constexpr AlignmentK = max(make_tuple(cutlass::gemm::detail::is_k_major_A() ? AlignmentA : 1, + cutlass::gemm::detail::is_k_major_B() ? 
AlignmentB : 1)); + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue compute +using ElementScalar = float; // Element type for scalar values (alpha, beta) +using ElementIntermediate = cutlass::half_t; // Element type of intermediate result between GEMM and bias+activation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using EpiTileShape = cutlass::epilogue::collective::EpilogueTileAuto; +using ClusterShape = Shape<_1,_2,_1>; +using TileShapeK = Int<128 * 8 / sizeof_bits::value>; + +using KernelScheduleCooperative = conditional_t(), + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>; + +using KernelSchedulePingpong = conditional_t(), + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong>; + +using KernelSchedule = conditional_t; +using EpilogueSchedule = conditional_t; +using TileShape = conditional_t, Shape<_128,_256,TileShapeK>>; + +// GEMM setup + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpiTileShape, + ElementAccumulator, ElementCompute, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::AccCastLinCombEltActScale< + Quantize, + ActivationFn, + ElementD, + ElementCompute, + ElementC, + ElementScalar, + conditional_t + > + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, 
ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + GroupProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference GEMM setup + +using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, EpiTileShape, + ElementAccumulator, ElementCompute, + void, LayoutC *, AlignmentC, + ElementIntermediate, LayoutD *, AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::ScaledAcc + >::CollectiveOp; + +using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogueRef::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + GroupProblemShape, + CollectiveMainloopRef, + CollectiveEpilogueRef +>; + +using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = GemmKernel::InternalStrideA; +using StrideB = GemmKernel::InternalStrideB; +using StrideC = GemmKernel::InternalStrideC; +using StrideD = GemmKernel::InternalStrideD; + +// Host-side allocations +std::vector problem_shapes_host; + +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_col_D_host; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; +std::vector scale_host; + +// Device-side allocations +cutlass::DeviceAllocation problem_shapes; + +cutlass::DeviceAllocation 
block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_D_ref; +cutlass::DeviceAllocation block_D_ref_gemm; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_D_ref; +cutlass::DeviceAllocation offset_col_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation ptr_alpha; +cutlass::DeviceAllocation ptr_beta; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; +cutlass::DeviceAllocation block_scale; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +using Options = GroupedGemmOptions; + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_cols_D = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + cutlass::gemm::GemmCoord const& problem_size = options.problem_sizes[i]; + auto problem_shape_ref = make_shape(problem_size.m(), problem_size.n(), problem_size.k()); + auto [M, N, K] = problem_shape_ref; + + problem_shapes_host.push_back(make_shape(M, N, K)); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_col_D_host.push_back(total_cols_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + 
stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + if constexpr (BiasBroadcast) { + get<1>(stride_C_host.back()) = 0; + elements_C = M; + } + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_cols_D += N; + } + offset_col_D_host.push_back(total_cols_D); + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_D_ref.reset(total_elements_D); + block_D_ref_gemm.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + block_scale.reset(1); + + problem_shapes.reset(options.groups); + + ptr_A.reset(options.groups); + ptr_B.reset(options.groups); + ptr_C.reset(options.groups); + ptr_D.reset(options.groups); + ptr_D_ref.reset(options.groups); + ptr_alpha.reset(options.groups); + ptr_beta.reset(options.groups); + + stride_A.reset(options.groups); + stride_B.reset(options.groups); + stride_C.reset(options.groups); + stride_D.reset(options.groups); + + offset_col_D.reset(options.groups + 1); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_D_ref_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + std::mt19937 rng(2024); + std::uniform_real_distribution alpha_dist(0.5, 2.0); + std::uniform_real_distribution beta_dist(1.0, 1.0); 
// (0.0, 4.0); + std::uniform_real_distribution scale_dist(0.5, 1.0); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host[i] = block_A.get() + offset_A[i]; + ptr_B_host[i] = block_B.get() + offset_B[i]; + ptr_C_host[i] = block_C.get() + offset_C[i]; + ptr_D_host[i] = block_D.get() + offset_D[i]; + ptr_D_ref_host[i] = block_D_ref_gemm.get() + offset_D[i]; + alpha_host.push_back(options.alpha == FLT_MAX ? alpha_dist(rng) : options.alpha); + beta_host.push_back(options.beta == FLT_MAX ? beta_dist(rng) : options.beta); + ptr_alpha_host[i] = block_alpha.get() + i; + ptr_beta_host[i] = block_beta.get() + i; + } + scale_host.push_back(scale_dist(rng)); + + problem_shapes.copy_from_host(problem_shapes_host.data()); + + ptr_A.copy_from_host(ptr_A_host.data()); + ptr_B.copy_from_host(ptr_B_host.data()); + ptr_C.copy_from_host(ptr_C_host.data()); + ptr_D.copy_from_host(ptr_D_host.data()); + ptr_D_ref.copy_from_host(ptr_D_ref_host.data()); + ptr_alpha.copy_from_host(ptr_alpha_host.data()); + ptr_beta.copy_from_host(ptr_beta_host.data()); + + stride_A.copy_from_host(stride_A_host.data()); + stride_B.copy_from_host(stride_B_host.data()); + stride_C.copy_from_host(stride_C_host.data()); + stride_D.copy_from_host(stride_D_host.data()); + + offset_col_D.copy_from_host(offset_col_D_host.data()); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + block_scale.copy_from_host(scale_host.data()); + + cutlass::reference::device::BlockFillRandom(block_A.get(), block_A.size(), 2024, options.dist_a); + cutlass::reference::device::BlockFillRandom(block_B.get(), block_B.size(), 2025, options.dist_b); + cutlass::reference::device::BlockFillRandom(block_C.get(), block_C.size(), 2026, options.dist_c); +} + +template +auto args_from_options_common(const Options &options, bool host_problem_shapes_available) +{ + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = options.sm_count > 0 + ? 
options.sm_count + : cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args{}; + + if (options.alpha != FLT_MAX) { + fusion_args.alpha = options.alpha; + } + else { + fusion_args.alpha_ptr_array = ptr_alpha.get(); + fusion_args.dAlpha = {{},{},1}; + } + + auto problem_shapes_host_ptr = host_problem_shapes_available ? problem_shapes_host.data() : nullptr; + + return make_tuple(fusion_args, hw_info, problem_shapes_host_ptr); +} + +template <typename GemmT> +typename GemmT::Arguments +args_from_options(const Options &options, bool host_problem_shapes_available); + +template <> +Gemm::Arguments +args_from_options<Gemm>(const Options &options, bool host_problem_shapes_available) +{ + auto [fusion_args, hw_info, problem_shapes_host_ptr] = args_from_options_common<Gemm>(options, host_problem_shapes_available); + + fusion_args.beta = options.beta; + fusion_args.beta_ptr = block_beta.get(); + fusion_args.scale_ptr = Quantize ? 
block_scale.get() : nullptr; + + using RasterOrderOptions = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams::RasterOrderOptions; + RasterOrderOptions raster = [&] { + switch (options.raster) { + case Options::RasterOrderOptions::Heuristic: return RasterOrderOptions::Heuristic; + case Options::RasterOrderOptions::AlongM: return RasterOrderOptions::AlongM; + case Options::RasterOrderOptions::AlongN: return RasterOrderOptions::AlongN; + default: return RasterOrderOptions::Heuristic; + } + }(); + + Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_shapes.get(), problem_shapes_host_ptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info, + {options.swizzle, raster} + }; + + return arguments; +} + +template <> +GemmRef::Arguments +args_from_options<GemmRef>(const Options &options, bool host_problem_shapes_available) +{ + auto [fusion_args, hw_info, problem_shapes_host_ptr] = args_from_options_common<GemmRef>(options, host_problem_shapes_available); + + using RasterOrderOptions = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams::RasterOrderOptions; + RasterOrderOptions raster = [&] { + switch (options.raster) { + case Options::RasterOrderOptions::Heuristic: return RasterOrderOptions::Heuristic; + case Options::RasterOrderOptions::AlongM: return RasterOrderOptions::AlongM; + case Options::RasterOrderOptions::AlongN: return RasterOrderOptions::AlongN; + default: return RasterOrderOptions::Heuristic; + } + }(); + + GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_shapes.get(), problem_shapes_host_ptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, nullptr, stride_C.get(), ptr_D_ref.get(), stride_D.get()}, + hw_info, + {options.swizzle, raster} + }; + + return arguments; +} + +bool verify(const Options &options) { + bool 
passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto const& problem_size = options.problem_sizes[i]; + auto problem_shape = make_shape(problem_size.m(), problem_size.n(), problem_size.k()); + auto [M, N, K] = problem_shape; + + bool group_passed = false; + if constexpr (ExactMode) { + group_passed = cutlass::reference::device::BlockCompareEqual( + block_D_ref.get() + offset_D[i], block_D.get() + offset_D[i], M * N); + } + else { + group_passed = cutlass::reference::device::BlockCompareRelativelyEqual( + block_D_ref.get() + offset_D[i], block_D.get() + offset_D[i], M * N, ElementD(options.tolerance), ElementD(options.nonzero_floor)); + } + if (!group_passed && options.verbose) { + std::cout << "Group " << i << " failed" << std::endl; + print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get() + offset_D[i], make_shape(M, N, 1), stride_D_host[i])); + print("D computed: "); print_device_tensor(make_tensor(block_D.get() + offset_D[i], make_shape(M, N, 1), stride_D_host[i])); + } + passed &= group_passed; + } + return passed; +} + +bool run(Options &options, bool host_problem_shapes_available = true) +{ + // Apply some restrictions on Grouped GEMM options + for (int i = 0; i < options.groups; ++i) { + if (options.problem_sizes[i].m() != options.problem_sizes[0].m()) { + throw std::runtime_error("Variable M problem size is not supported by verification kernel"); + } + } + if (options.beta != FLT_MAX && options.beta != 1.f && options.beta != 0.f) { + throw std::runtime_error("Specifying beta != 0/1 is not supported by verification kernel"); + } + + allocate(options); + initialize(options); + + std::cout << "Groups : " << options.groups << std::endl; + std::cout << "Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << shape_string(make_tuple(options.problem_sizes[i].m(), options.problem_sizes[i].n(), options.problem_sizes[i].k())); + std::cout << ", " << alpha_host[i] << ", " << 
beta_host[i] << std::endl; + } + std::cout << "Data types: " << problem_desc_string() << std::endl; + std::cout << "Activation function: " << activation_func_string() << std::endl; + std::cout << "Kernel schedule: " << kernel_schedule_string() << std::endl; + std::cout << "GEMM tile shape: " << shape_string(TileShape{}) << std::endl; + std::cout << "Epi tile shape: " << epilogue_tile_string(EpiTileShape{}) << std::endl; + std::cout << "Cluster shape: " << shape_string(ClusterShape{}) << std::endl; + std::cout << "Rasterization: " << options.raster_string() << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << "Options: Quantize = " << Quantize << ", Exact = " << ExactMode << ", BiasBroadcast = " << BiasBroadcast << std::endl; + + Runner gemm(args_from_options(options, host_problem_shapes_available)); + Runner gemm_ref(args_from_options(options, host_problem_shapes_available)); + + auto run_fused = [&](){ gemm.run(); }; + auto run_ref_gemm = [&](){ gemm_ref.run(); }; + auto run_activation = [&](){ + do_activation( + block_D_ref.get(), + block_D_ref_gemm.get(), + Quantize ? block_scale.get() : static_cast(nullptr), + options.beta != 0.f ? block_C.get() : static_cast(nullptr), + BiasBroadcast, + offset_col_D.get(), + options.groups, + options.problem_sizes.at(0).m(), // all problems have same M + offset_col_D_host[options.groups], + false); + }; + auto run_unfused = [&](){ run_ref_gemm(); run_activation(); }; + + run_fused(); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Correctness check + bool passed = true; + if (options.verify) { + run_unfused(); + CUDA_CHECK(cudaDeviceSynchronize()); + + passed = verify(options); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + } + + if (options.iterations > 0) + { + auto benchmark = [&](auto name, auto func) + { + BenchmarkResult result = run_benchmark(func, options.warmup, options.iterations); + double avg_tflops = double(options.total_flops()) / result.avg_runtime_ms / 1e9; // FLOP/ms -> TFLOP/s + printf(options.csv ? "%s,%.3f,%.0f\n" : "%20s %20.3f %20.0f\n", + name, result.avg_runtime_ms, avg_tflops); + }; + printf(options.csv ? "%s,%s,%s\n" : "%20s %20s %20s\n", + "Kernel", "Runtime (ms)", "Throughput (Tflop/s)"); + benchmark("Fused", run_fused); + benchmark("Unfused", run_unfused); + benchmark("GEMM only", run_ref_gemm); + } + + return passed; +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + return EXIT_SUCCESS; + } + + try { + Options options(AlignmentM, AlignmentN, AlignmentK); + cutlass::CommandLine cmd(argc, args); + options.parse(cmd); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.device >= 0) { + CUDA_CHECK(cudaSetDevice(options.device)); + } + else { + CUDA_CHECK(cudaGetDevice(&options.device)); + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, options.device)); + if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return EXIT_SUCCESS; + } + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + if (!run(options, false)) { + return EXIT_FAILURE; + } +#endif + } + catch (std::exception const& e) { + std::cerr << e.what() << std::endl; + return EXIT_FAILURE; + 
} + + return EXIT_SUCCESS; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/113_hopper_gemm_activation_fusion/113_hopper_grouped_gemm_fused_gated_act.cu b/examples/113_hopper_gemm_activation_fusion/113_hopper_grouped_gemm_fused_gated_act.cu new file mode 100644 index 000000000..f111acbd7 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/113_hopper_grouped_gemm_fused_gated_act.cu @@ -0,0 +1,694 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#include +#include +#include +#include +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" +#include "options.hpp" +#include "utils.hpp" +#include "tile_scheduler_group.hpp" +#include "gated_stride.hpp" +#include "gated_builder.hpp" +#include "activation_kernel.cuh" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +#if 0 +template <class T> +using ActivationFn = cutlass::epilogue::thread::ReLu<T>; +#elif 1 +template <class T> +using ActivationFn = cutlass::epilogue::thread::SiLu<T>; +#else +template <class T> +using ActivationFn = cutlass::epilogue::thread::Identity<T>; +#endif + +bool constexpr IsFp8 = true; // whether to run with fp8 or fp16 input/output +bool constexpr Quantize = true; // whether to quantize output with a per-tensor scale factor +bool constexpr ExactMode = false; // whether to reproduce unfused dual gemm+activation exactly +bool constexpr BiasBroadcast = true; // whether bias is broadcast along columns in each group +bool constexpr Pingpong = true; // whether to use pingpong schedule + +using ProblemShape = Shape<int,int,int>; // <M,N,K> per group +using GroupProblemShape = cutlass::gemm::GroupProblemShape<ProblemShape>; + +using GatedProblemShape = decltype(cutlass::sm90_make_gated_shape<0>(ProblemShape{})); +using GatedGroupProblemShape = cutlass::gemm::GroupProblemShape<GatedProblemShape>; + +// A matrix configuration +using ElementA = conditional_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = conditional_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operand +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C matrix operand +constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = 
conditional_t; // Element type for D matrix operand +using LayoutD = cutlass::layout::ColumnMajor; // Layout type for D matrix operand +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Alignment of D matrix in units of elements (up to 16 bytes) + +int constexpr AlignmentM = max(make_tuple(cutlass::gemm::detail::is_mn_major_A() ? AlignmentA : 1, + cutlass::gemm::detail::is_m_major_C() ? AlignmentC : 1, + cutlass::gemm::detail::is_m_major_C() ? AlignmentD : 1)); +int constexpr AlignmentN = max(make_tuple(cutlass::gemm::detail::is_mn_major_B() ? AlignmentB : 1, + cutlass::gemm::detail::is_n_major_C() ? AlignmentC : 1, + cutlass::gemm::detail::is_n_major_C() ? AlignmentD : 1)); +int constexpr AlignmentK = max(make_tuple(cutlass::gemm::detail::is_k_major_A() ? AlignmentA : 1, + cutlass::gemm::detail::is_k_major_B() ? AlignmentB : 1)); + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue compute +using ElementScalar = float; // Element type for scalar values (alpha, beta) +using ElementIntermediate = cutlass::half_t; // Element type of intermediate result between GEMM and bias+activation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using EpiTileShape = cutlass::epilogue::collective::EpilogueTileAuto; // Epilogue sub-tile shape +using ClusterShape = Shape<_1,_2,_1>; // Cluster shape +using TileShapeK = Int<128 * 8 / sizeof_bits::value>; + +using KernelScheduleCooperative = conditional_t(), + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>; + +using KernelSchedulePingpong = conditional_t(), + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong>; + 
+using KernelSchedule = conditional_t; +using EpilogueSchedule = conditional_t; +using TileShape = conditional_t, Shape<_128,_256,TileShapeK>>; + +// Gated GEMM setup + +using CollectiveEpilogue = typename cutlass::epilogue::collective::Sm90CollectiveBuilderGated< + OperatorClass, + TileShape, ClusterShape, + EpiTileShape, + ElementAccumulator, ElementCompute, ElementScalar, + conditional_t, + ElementC, LayoutC *, AlignmentC, + ElementD, LayoutD *, AlignmentD, + EpilogueSchedule, + ActivationFn, + Quantize + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::Sm90CollectiveBuilderGated< + OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + GatedGroupProblemShape, + CollectiveMainloop, + CollectiveEpilogue, + GroupSchedulerTileShapeDependent +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference GEMM setup + +using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, EpiTileShape, + ElementAccumulator, ElementCompute, + void, LayoutC *, AlignmentC, + ElementIntermediate, LayoutD *, AlignmentD, + EpilogueSchedule, + cutlass::epilogue::fusion::ScaledAcc + >::CollectiveOp; + +using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogueRef::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + 
GroupProblemShape, + CollectiveMainloopRef, + CollectiveEpilogueRef +>; + +using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = GemmKernel::InternalStrideA; +using StrideB = GemmKernel::InternalStrideB; +using StrideC = GemmKernel::InternalStrideC; +using StrideD = remove_pointer_t; + +using StrideARef = GemmKernelRef::InternalStrideA; +using StrideBRef = GemmKernelRef::InternalStrideB; +using StrideCRef = GemmKernelRef::InternalStrideC; +using StrideDRef = GemmKernelRef::InternalStrideD; + +// Host-side allocations +std::vector problem_shapes_host; +std::vector problem_shapes_ref_host; + +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; +std::vector offset_D_ref; +std::vector offset_col_D_host; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector stride_A_ref_host; +std::vector stride_B_ref_host; +std::vector stride_C_ref_host; +std::vector stride_D_ref_host; + +std::vector alpha_host; +std::vector beta_host; +std::vector scale_host; + +// Device-side allocations +cutlass::DeviceAllocation problem_shapes; +cutlass::DeviceAllocation problem_shapes_ref; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_D_ref_gemm; +cutlass::DeviceAllocation block_D_ref; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_D_ref; +cutlass::DeviceAllocation offset_col_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +cutlass::DeviceAllocation stride_A_ref; +cutlass::DeviceAllocation stride_B_ref; +cutlass::DeviceAllocation stride_C_ref; +cutlass::DeviceAllocation stride_D_ref; + +// Note, this is 
an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation ptr_alpha; +cutlass::DeviceAllocation ptr_beta; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; +cutlass::DeviceAllocation block_scale; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +using Options = GroupedGemmOptions; + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_cols_D = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + cutlass::gemm::GemmCoord const& problem_size = options.problem_sizes[i]; + auto problem_shape_ref = make_shape(problem_size.m(), problem_size.n(), problem_size.k()); + auto [M, N, K] = problem_shape_ref; + auto NC = BiasBroadcast ? 
1 : N; + + problem_shapes_host.push_back(cutlass::sm90_make_gated_shape<0>(make_shape(M, N, K))); + problem_shapes_ref_host.push_back(make_shape(M, N, K)); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + offset_D_ref.push_back(total_elements_D * 2); + offset_col_D_host.push_back(total_cols_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * NC; + int64_t elements_D = M/2 * N; + + stride_A_host.push_back(cutlass::sm90_make_gated_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + stride_C_host.push_back(cutlass::sm90_make_gated_packed_stride(StrideC{}, {M, NC, 1})); + stride_D_host.push_back(cutlass::sm90_make_gated_packed_stride(StrideD{}, {M/2, N, 1})); + + stride_A_ref_host.push_back(cutlass::make_cute_packed_stride(StrideARef{}, {M, K, 1})); + stride_B_ref_host.push_back(cutlass::make_cute_packed_stride(StrideBRef{}, {N, K, 1})); + stride_C_ref_host.push_back(cutlass::make_cute_packed_stride(StrideCRef{}, {M, NC, 1})); + stride_D_ref_host.push_back(cutlass::make_cute_packed_stride(StrideDRef{}, {M, N, 1})); + + if constexpr (BiasBroadcast) { + get<1>(stride_C_host.back()) = 0; + get<1>(stride_C_ref_host.back()) = 0; + } + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + total_cols_D += N; + } + offset_col_D_host.push_back(total_cols_D); + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_D_ref_gemm.reset(total_elements_D * 2); + block_D_ref.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + block_scale.reset(1); + + problem_shapes.reset(options.groups); + problem_shapes_ref.reset(options.groups); + + 
ptr_A.reset(options.groups); + ptr_B.reset(options.groups); + ptr_C.reset(options.groups); + ptr_D.reset(options.groups); + ptr_D_ref.reset(options.groups); + ptr_alpha.reset(options.groups); + ptr_beta.reset(options.groups); + + stride_A.reset(options.groups); + stride_B.reset(options.groups); + stride_C.reset(options.groups); + stride_D.reset(options.groups); + + stride_A_ref.reset(options.groups); + stride_B_ref.reset(options.groups); + stride_C_ref.reset(options.groups); + stride_D_ref.reset(options.groups); + offset_col_D.reset(options.groups + 1); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_D_ref_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + std::mt19937 rng(2024); + std::uniform_real_distribution alpha_dist(0.5, 2.0); + std::uniform_real_distribution beta_dist(1.0, 1.0); // (0.0, 4.0); + std::uniform_real_distribution scale_dist(0.5, 1.0); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host[i] = block_A.get() + offset_A[i]; + ptr_B_host[i] = block_B.get() + offset_B[i]; + ptr_C_host[i] = block_C.get() + offset_C[i]; + ptr_D_host[i] = block_D.get() + offset_D[i]; + ptr_D_ref_host[i] = block_D_ref_gemm.get() + offset_D_ref[i]; + alpha_host.push_back(options.alpha == FLT_MAX ? alpha_dist(rng) : options.alpha); + beta_host.push_back(options.beta == FLT_MAX ? 
beta_dist(rng) : options.beta); + ptr_alpha_host[i] = block_alpha.get() + i; + ptr_beta_host[i] = block_beta.get() + i; + } + scale_host.push_back(scale_dist(rng)); + + problem_shapes.copy_from_host(problem_shapes_host.data()); + problem_shapes_ref.copy_from_host(problem_shapes_ref_host.data()); + + ptr_A.copy_from_host(ptr_A_host.data()); + ptr_B.copy_from_host(ptr_B_host.data()); + ptr_C.copy_from_host(ptr_C_host.data()); + ptr_D.copy_from_host(ptr_D_host.data()); + ptr_D_ref.copy_from_host(ptr_D_ref_host.data()); + ptr_alpha.copy_from_host(ptr_alpha_host.data()); + ptr_beta.copy_from_host(ptr_beta_host.data()); + + stride_A.copy_from_host(stride_A_host.data()); + stride_B.copy_from_host(stride_B_host.data()); + stride_C.copy_from_host(stride_C_host.data()); + stride_D.copy_from_host(stride_D_host.data()); + + stride_A_ref.copy_from_host(stride_A_ref_host.data()); + stride_B_ref.copy_from_host(stride_B_ref_host.data()); + stride_C_ref.copy_from_host(stride_C_ref_host.data()); + stride_D_ref.copy_from_host(stride_D_ref_host.data()); + offset_col_D.copy_from_host(offset_col_D_host.data()); + + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); + block_scale.copy_from_host(scale_host.data()); + + cutlass::reference::device::BlockFillRandom(block_A.get(), block_A.size(), 2024, options.dist_a); + cutlass::reference::device::BlockFillRandom(block_B.get(), block_B.size(), 2025, options.dist_b); + cutlass::reference::device::BlockFillRandom(block_C.get(), block_C.size(), 2026, options.dist_c); +} + +template +auto args_from_options_common(const Options &options, bool host_problem_shapes_available) +{ + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = options.sm_count > 0 + ? 
options.sm_count + : cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args{}; + + if (options.alpha != FLT_MAX) { + fusion_args.alpha = options.alpha; + } + else { + fusion_args.alpha_ptr_array = ptr_alpha.get(); + fusion_args.dAlpha = {{},{},1}; + } + + return make_tuple(fusion_args, hw_info); +} + +template <typename GemmT> +typename GemmT::Arguments +args_from_options(const Options &options, bool host_problem_shapes_available); + +template <> +Gemm::Arguments +args_from_options<Gemm>(const Options &options, bool host_problem_shapes_available) +{ + auto [fusion_args, hw_info] = args_from_options_common<Gemm>(options, host_problem_shapes_available); + auto problem_shapes_host_ptr = host_problem_shapes_available ? problem_shapes_host.data() : nullptr; + + fusion_args.beta = options.beta; + fusion_args.beta_ptr = block_beta.get(); + fusion_args.scale_ptr = Quantize ? block_scale.get() : nullptr; + fusion_args.ptr_D = ptr_D.get(); + fusion_args.dD = stride_D.get(); + fusion_args.sm_count = hw_info.sm_count; + + using RasterOrderOptions = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams::RasterOrderOptions; + RasterOrderOptions raster = [&] { + switch (options.raster) { + case Options::RasterOrderOptions::Heuristic: return RasterOrderOptions::Heuristic; + case Options::RasterOrderOptions::AlongM: return RasterOrderOptions::AlongM; + case Options::RasterOrderOptions::AlongN: return RasterOrderOptions::AlongN; + default: return RasterOrderOptions::Heuristic; + } + }(); + + Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_shapes.get(), problem_shapes_host_ptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), {}, {}}, + hw_info, + {options.swizzle, raster} + }; + + return arguments; +} + +template <> +GemmRef::Arguments +args_from_options<GemmRef>(const Options 
&options, bool host_problem_shapes_available) +{ + auto [fusion_args, hw_info] = args_from_options_common(options, host_problem_shapes_available); + auto problem_shapes_host_ptr = host_problem_shapes_available ? problem_shapes_ref_host.data() : nullptr; + + using RasterOrderOptions = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90GroupParams::RasterOrderOptions; + RasterOrderOptions raster = [&] { + switch (options.raster) { + case Options::RasterOrderOptions::Heuristic: return RasterOrderOptions::Heuristic; + case Options::RasterOrderOptions::AlongM: return RasterOrderOptions::AlongM; + case Options::RasterOrderOptions::AlongN: return RasterOrderOptions::AlongN; + default: return RasterOrderOptions::Heuristic; + } + }(); + + GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_shapes_ref.get(), problem_shapes_host_ptr}, + {ptr_A.get(), stride_A_ref.get(), ptr_B.get(), stride_B_ref.get()}, + {fusion_args, nullptr, stride_C_ref.get(), ptr_D_ref.get(), stride_D_ref.get()}, + hw_info, + {options.swizzle, raster} + }; + + return arguments; +} + +bool verify(const Options &options) { + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto const& problem_size = options.problem_sizes[i]; + auto problem_shape = make_shape(problem_size.m(), problem_size.n(), problem_size.k()); + auto [M, N, K] = problem_shape; + + bool group_passed = false; + if constexpr (ExactMode) { + group_passed = cutlass::reference::device::BlockCompareEqual( + block_D_ref.get() + offset_D[i], block_D.get() + offset_D[i], M/2 * N); + } + else { + group_passed = cutlass::reference::device::BlockCompareRelativelyEqual( + block_D_ref.get() + offset_D[i], block_D.get() + offset_D[i], M/2 * N, ElementD(options.tolerance), ElementD(options.nonzero_floor)); + } + if (!group_passed && options.verbose) { + std::cout << "Group " << i << " failed" << std::endl; + print("D reference: "); 
print_device_tensor(make_tensor(block_D_ref.get() + offset_D[i], make_shape(M/2, N, 1), GenColMajor{})); + print("D computed: "); print_device_tensor(make_tensor(block_D.get() + offset_D[i], make_shape(M/2, N, 1), GenColMajor{})); + } + passed &= group_passed; + } + return passed; +} + +bool run(Options &options, bool host_problem_shapes_available = true) +{ + // Apply some restrictions on Grouped GEMM options + for (int i = 0; i < options.groups; ++i) { + if (options.problem_sizes[i].m() != options.problem_sizes[0].m()) { + throw std::runtime_error("Variable M problem size is not supported by verification kernel"); + } + } + if (options.beta != FLT_MAX && options.beta != 1.f && options.beta != 0.f) { + throw std::runtime_error("Specifying beta != 0/1 is not supported by verification kernel"); + } + + allocate(options); + initialize(options); + + std::cout << "Groups : " << options.groups << std::endl; + std::cout << "Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << shape_string(make_tuple(options.problem_sizes[i].m(), options.problem_sizes[i].n(), options.problem_sizes[i].k())); + std::cout << ", " << alpha_host[i] << ", " << beta_host[i] << std::endl; + } + std::cout << "Data types: " << problem_desc_string() << std::endl; + std::cout << "Activation function: " << activation_func_string() << std::endl; + std::cout << "Kernel schedule: " << kernel_schedule_string() << std::endl; + std::cout << "GEMM tile shape: " << shape_string(TileShape{}) << std::endl; + std::cout << "Epi tile shape: " << epilogue_tile_string(EpiTileShape{}) << std::endl; + std::cout << "Cluster shape: " << shape_string(ClusterShape{}) << std::endl; + std::cout << "Rasterization: " << options.raster_string() << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << "Options: Quantize = " << Quantize << ", Exact = " << ExactMode << ", BiasBroadcast = " << BiasBroadcast << std::endl; + + Runner 
gemm(args_from_options(options, host_problem_shapes_available)); + Runner gemm_ref(args_from_options(options, host_problem_shapes_available)); + + auto run_fused = [&](){ gemm.run(); }; + auto run_ref_gemm = [&](){ gemm_ref.run(); }; + auto run_activation = [&](){ + do_activation( + block_D_ref.get(), + block_D_ref_gemm.get(), + Quantize ? block_scale.get() : static_cast(nullptr), + options.beta != 0.f ? block_C.get() : static_cast(nullptr), + BiasBroadcast, + offset_col_D.get(), + options.groups, + options.problem_sizes.at(0).m() / 2, // all problems have same M + offset_col_D_host[options.groups], + true); + }; + auto run_unfused = [&](){ run_ref_gemm(); run_activation(); }; + + run_fused(); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Correctness check + bool passed = true; + if (options.verify) { + run_unfused(); + CUDA_CHECK(cudaDeviceSynchronize()); + + passed = verify(options); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + } + + if (options.iterations > 0) + { + auto benchmark = [&](auto name, auto func) + { + BenchmarkResult result = run_benchmark(func, options.warmup, options.iterations); + double avg_tflops = double(options.total_flops()) / result.avg_runtime_ms / 1e9; // FLOP/ms -> TFLOP/s + printf(options.csv ? "%s,%.3f,%.0f\n" : "%20s %20.3f %20.0f\n", + name, result.avg_runtime_ms, avg_tflops); + }; + printf(options.csv ? 
"%s,%s,%s\n" : "%20s %20s %20s\n", + "Kernel", "Runtime (ms)", "Throughput (Tflop/s)"); + benchmark("Fused", run_fused); + benchmark("Unfused", run_unfused); + benchmark("GEMM only", run_ref_gemm); + } + + return passed; +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + return EXIT_SUCCESS; + } + + try { + Options options(AlignmentM, AlignmentN, AlignmentK); + cutlass::CommandLine cmd(argc, args); + options.parse(cmd); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.device >= 0) { + CUDA_CHECK(cudaSetDevice(options.device)); + } + else { + CUDA_CHECK(cudaGetDevice(&options.device)); + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, options.device)); + if (props.major != 9 || props.minor != 0) { + std::cerr << "This example requires a GPU of NVIDIA's Hopper Architecture (compute capability 90).\n"; + return EXIT_SUCCESS; + } + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + if (!run(options, false)) { + return EXIT_FAILURE; + } +#endif + } + catch (std::exception const& e) { + std::cerr << e.what() << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/113_hopper_gemm_activation_fusion/CMakeLists.txt b/examples/113_hopper_gemm_activation_fusion/CMakeLists.txt new file mode 100644 index 000000000..e948fbf64 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (c) 2023 - 2026 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +cutlass_example_add_executable( + 113_hopper_gemm_fused_act + 113_hopper_gemm_fused_act.cu + ) + +cutlass_example_add_executable( + 113_hopper_gemm_fused_gated_act + 113_hopper_gemm_fused_gated_act.cu + ) + +cutlass_example_add_executable( + 113_hopper_grouped_gemm_fused_act + 113_hopper_grouped_gemm_fused_act.cu + ) + +cutlass_example_add_executable( + 113_hopper_grouped_gemm_fused_gated_act + 113_hopper_grouped_gemm_fused_gated_act.cu + ) + +add_custom_target( + 113_hopper_gemm_activation_fusion + DEPENDS + 113_hopper_gemm_fused_act + 113_hopper_gemm_fused_gated_act + 113_hopper_grouped_gemm_fused_act + 113_hopper_grouped_gemm_fused_gated_act +) diff --git a/examples/113_hopper_gemm_activation_fusion/activation_kernel.cuh b/examples/113_hopper_gemm_activation_fusion/activation_kernel.cuh new file mode 100644 index 000000000..a368d2b5f --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/activation_kernel.cuh @@ -0,0 +1,220 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +template +CUTLASS_DEVICE +constexpr static U +array_convert(T const& input) +{ + using SrcType = typename T::Element; + using DstType = typename U::Element; + static_assert(T::kElements == U::kElements); + using Converter = cutlass::NumericArrayConverter; + return Converter{}(input); +} + +template +CUTLASS_DEVICE +int64_t +lower_bound(T const* values, size_t const size, T const target) { + T const* low = values; + T const* high = values + size; + while (low < high) { + T const* mid = low + (high - low) / 2; + if (*mid < target) { + low = mid + 1; + } + else { + high = mid; + } + } + return static_cast(low - values); +} + +template +CUTLASS_DEVICE +T +load_vec(T const& src) { + constexpr int B = cute::min(128, cute::sizeof_bits_v); + constexpr int N = cute::sizeof_bits_v / B; + using V = cute::uint_bit_t; + V v[N]; + V const* vptr = reinterpret_cast(&src); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + v[i] = vptr[i]; + } + return 
*reinterpret_cast(&v); +} + +template +CUTLASS_DEVICE +void +store_vec(T& dst, T const& src) { + constexpr int B = cute::min(128, cute::sizeof_bits_v); + constexpr int N = cute::sizeof_bits_v / B; + using V = cute::uint_bit_t; + V v[N]; + V* vptr = reinterpret_cast(&dst); + + *reinterpret_cast(&v) = src; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + vptr[i] = v[i]; + } +} + +template < + int NumThreads, + template class ActFn, + class ElementOutput, + class ElementInput, + class ElementBias, + class ElementCompute +> +CUTLASS_GLOBAL +void +activation_kernel( + ElementOutput* output, + ElementInput const* input, + ElementCompute const* scale_ptr, + ElementBias const* bias_ptr, + bool bias_is_broadcast, + int64_t const* group_col_offset, + int num_groups, + int64_t stride, + bool gated) +{ + int64_t const tid = threadIdx.x; + int64_t const col = blockIdx.x; + if (col >= group_col_offset[num_groups]) + { + return; + } + + size_t gated_size_mul = gated ? 2 : 1; + size_t gated_off = gated ? stride : 0; + + + input = input + col * stride * gated_size_mul; + output = output + col * stride; + + float const quant_scale = scale_ptr ? *scale_ptr : 1.f; + + if (bias_ptr) { + int64_t group = 0; + if (bias_is_broadcast) { + group = lower_bound(group_col_offset, num_groups, (int64_t) col + 1) - 1; + } + size_t bias_offset = (bias_is_broadcast ? 
group : col) * stride * gated_size_mul; + bias_ptr = bias_ptr + bias_offset; + } + + // Vectorize all loads up to 128 bits + constexpr int64_t VecSize = 128 / cute::max(cutlass::sizeof_bits_v, + cute::max(cutlass::sizeof_bits_v, + cutlass::sizeof_bits_v)); + + using BiasElem = cutlass::Array; + using InputElem = cutlass::Array; + using OutputElem = cutlass::Array; + using ComputeElem = cutlass::Array; + + auto input_vec = reinterpret_cast(input); + auto output_vec = reinterpret_cast(output); + auto bias_ptr_vec = reinterpret_cast(bias_ptr); + + int64_t const num_elems_in_col = stride / VecSize; + int64_t const gated_off_vec = gated_off / VecSize; + + ActFn fn{}; + for (int64_t elem_index = tid; elem_index < num_elems_in_col; elem_index += NumThreads) + { + auto fc1_value = array_convert(load_vec(input_vec[elem_index + gated_off_vec])); + if (bias_ptr_vec) + { + fc1_value = fc1_value + array_convert(load_vec(bias_ptr_vec[elem_index + gated_off_vec])); + } + auto gate_act = fn(fc1_value); + + if (gated) + { + auto gate_mul = array_convert(load_vec(input_vec[elem_index])); + if (bias_ptr_vec) + { + gate_mul = gate_mul + array_convert(load_vec(bias_ptr_vec[elem_index])); + } + gate_act = gate_act * gate_mul; + } + + store_vec(output_vec[elem_index], array_convert(gate_act * quant_scale)); + } +} + +template < + template class ActFn, + class ElementOutput, + class ElementInput, + class ElementBias, + class ElementCompute +> +void do_activation( + ElementOutput* output, + ElementInput const* input, + ElementCompute const* scale, + ElementBias const* bias, + bool bias_is_broadcast, + int64_t const* group_col_offset, + int num_groups, + int64_t stride, + int64_t num_tokens, + bool gated) +{ + int const blocks = num_tokens; + int constexpr threads = 256; + activation_kernel<<>>( + output, + input, + scale, + bias, + bias_is_broadcast, + group_col_offset, + num_groups, + stride, + gated); +} diff --git a/examples/113_hopper_gemm_activation_fusion/gated_builder.hpp 
b/examples/113_hopper_gemm_activation_fusion/gated_builder.hpp new file mode 100644 index 000000000..03c0e40ab --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/gated_builder.hpp @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +// Temporary workaround for a circular-include issue in CuTe +#include "cute/atom/copy_atom.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "gated_stride.hpp" +#include "sm90_visitor_gated_act.hpp" + +namespace cutlass::detail { + +template +using GatedStride = cute::conditional_t< + cutlass::detail::is_major(), + decltype(replace(InputStride{}, cute::Stride{})), + decltype(replace(InputStride{}, cute::Stride< int64_t,int64_t, int64_t>{})) +>; + +template +using GatedOutputStride = cute::conditional_t< + cutlass::detail::is_major(), + decltype(replace(InputStride{}, cute::Stride{})), + decltype(replace(InputStride{}, cute::Stride< int64_t, int64_t>{})) +>; + +} + +namespace cutlass::gemm::collective { + +template < + class OpClass, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK_, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct Sm90CollectiveBuilderGated { + + using TileShape_MNK = decltype(cutlass::sm90_make_gated_shape<0>(TileShape_MNK_{})); + + using InternalStrideA = cute::remove_pointer_t>; + using GatedInternalStrideA = cutlass::detail::GatedStride<0, InternalStrideA>; + using StrideA = cute::conditional_t::value, GatedInternalStrideA *, GatedInternalStrideA>; + + using StrideB = cutlass::gemm::TagToStrideB_t; + + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, OpClass, + ElementA, StrideA, AlignmentA, + ElementB, StrideB, AlignmentB, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + StageCountType, KernelScheduleType + >::CollectiveOp; +}; + +} // namespace cutlass::gemm::collective + +namespace cutlass::epilogue::collective { + +template < + class OpClass, + class 
TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementScalar, + class ElementIntermediate, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueSchedule, + template class ActivationFn, + bool Quantize +> +struct Sm90CollectiveBuilderGated { + + static constexpr bool IsPtrArray = platform::is_pointer::value; + + using EpilogueTile_MN = + decltype(detail::sm90_compute_tile_shape_or_override()); + + // Factor out an 8x2 sub-tile in TileM + using GatedTileShape_MNK = decltype(cutlass::sm90_make_gated_shape<0>(TileShape_MNK{})); + using GatedEpilogueTile_MN = decltype(cutlass::sm90_make_gated_shape<0>(EpilogueTile_MN{})); + + using InternalStrideC = cute::remove_pointer_t>; + using GatedInternalStrideC = cutlass::detail::GatedStride<0, InternalStrideC>; + using StrideC = cute::conditional_t; + + using InternalStrideD = cute::remove_pointer_t>; + using GatedInternalStrideD = cutlass::detail::GatedStride<0, InternalStrideD>; + using StrideD = cute::conditional_t; + + // Gated kernel uses Aux output instead of D due to change in tensor shape + using InternalStrideDAux = cutlass::detail::GatedOutputStride<0, InternalStrideD>; + using StrideDAux = cute::conditional_t; + + using FusionOp = cutlass::epilogue::fusion::LinCombGatedActFunc< + Quantize, // Quantize + ActivationFn, // ActivationFn + StrideDAux, // GmemLayoutTagOutput + ElementD, // ElementOutput + ElementCompute, // ElementCompute + ElementC, // ElementSource + ElementScalar, // ElementScalar + ElementIntermediate, // ElementIntermediate + AlignmentD // Alignment + >; + + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + OpClass, + GatedTileShape_MNK, + ClusterShape_MNK, + GatedEpilogueTile_MN, + ElementAccumulator, + ElementCompute, + ElementC, + StrideC, + AlignmentC, + void, // output through AuxStore + StrideD, + AlignmentD, + 
EpilogueSchedule, + FusionOp + >::CollectiveOp; +}; + +} // namespace cutlass::epilogue::collective diff --git a/examples/113_hopper_gemm_activation_fusion/gated_stride.hpp b/examples/113_hopper_gemm_activation_fusion/gated_stride.hpp new file mode 100644 index 000000000..e3f1600d9 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/gated_stride.hpp @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/algorithm/tuple_algorithms.hpp" +#include "cutlass/detail/layout.hpp" + +/** + * Convenience functions for computing input/output strides for sm90 gated activation kernel + */ + +namespace cutlass { + +template +CUTLASS_HOST_DEVICE +auto +sm90_make_gated_shape(InputShape const& shape) { + using namespace cute; + using Tiler = Shape<_8,_2>; + return replace(shape, append(Tiler{}, shape_div(get(shape), Tiler{}))); +} + +template +CUTLASS_HOST_DEVICE +auto +sm90_make_gated_output_shape(InputShape const& shape) { + using namespace cute; + using Tiler = Shape<_8>; + return replace(shape, append(Tiler{}, shape_div(get(shape), Tiler{}))); +} + +// K-major gated gemm stride +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>, StrideIntT> +sm90_make_gated_packed_stride(cute::Stride, cute::Int<1>, StrideIntT>, cute::Shape shape_MKL) { + using namespace cute; + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto shape = sm90_make_gated_shape<0>(cute::transform(shape_MKL, [](auto s){ return static_cast(s); })); + auto stride = compact_order(shape, Step,_0,_4>{}); + return stride; +} + +// K-major gated gemm output stride +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>, StrideIntT> +sm90_make_gated_packed_stride(cute::Stride, cute::Int<1>, StrideIntT>, cute::Shape shape_MKL) { + using namespace cute; + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto shape = sm90_make_gated_output_shape<0>(cute::transform(shape_MKL, [](auto s){ return static_cast(s); })); + auto stride = compact_order(shape, Step,_0,_3>{}); + return stride; +} + +// K-major gated grouped gemm stride +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>, cute::Int<0>> +sm90_make_gated_packed_stride(cute::Stride, cute::Int<1>, cute::Int<0>>, cute::Shape shape_MKL) { + using namespace cute; + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto shape = sm90_make_gated_shape<0>(cute::transform(shape_MKL, [](auto s){ return static_cast(s); })); + auto stride = append(compact_order(take<0,2>(shape), Step,_0>{}), Int<0>{}); + return stride; +} + +// K-major gated grouped gemm output stride +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>, cute::Int<0>> +sm90_make_gated_packed_stride(cute::Stride, cute::Int<1>, cute::Int<0>>, cute::Shape shape_MKL) { + using namespace cute; + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto shape = sm90_make_gated_output_shape<0>(cute::transform(shape_MKL, [](auto s){ return static_cast(s); })); + auto stride = append(compact_order(take<0,2>(shape), Step,_0>{}), Int<0>{}); + return stride; +} + +// MN-major gated gemm stride +template +CUTLASS_HOST_DEVICE +cute::Stride,StrideIntT,cute::Int<8>>, StrideIntT, StrideIntT> +sm90_make_gated_packed_stride(cute::Stride,StrideIntT,cute::Int<8>>, StrideIntT, StrideIntT>, cute::Shape shape_MKL) { + using namespace cute; + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto shape = sm90_make_gated_shape<0>(cute::transform(shape_MKL, [](auto s){ return static_cast(s); })); + auto stride = compact_order(shape, Step,_3,_4>{}); + return stride; +} + +// MN-major gated gemm output stride +template +CUTLASS_HOST_DEVICE +cute::Stride,cute::Int<8>>, StrideIntT, StrideIntT> +sm90_make_gated_packed_stride(cute::Stride,cute::Int<8>>, StrideIntT, StrideIntT>, cute::Shape shape_MKL) { + using namespace cute; + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto shape = sm90_make_gated_output_shape<0>(cute::transform(shape_MKL, [](auto s){ return static_cast(s); })); + auto stride = compact_order(shape, Step,_2,_3>{}); + return stride; +} + +// MN-major gated grouped gemm stride +template +CUTLASS_HOST_DEVICE +cute::Stride,StrideIntT,cute::Int<8>>, StrideIntT, cute::Int<0>> +sm90_make_gated_packed_stride(cute::Stride,StrideIntT,cute::Int<8>>, StrideIntT, cute::Int<0>>, cute::Shape shape_MKL) { + using namespace cute; + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto shape = sm90_make_gated_shape<0>(cute::transform(shape_MKL, [](auto s){ return static_cast(s); })); + auto stride = append(compact_order(take<0,2>(shape), Step,_3>{}), Int<0>{}); + return stride; +} + +// MN-major gated grouped gemm output stride +template +CUTLASS_HOST_DEVICE +cute::Stride,cute::Int<8>>, StrideIntT, cute::Int<0>> +sm90_make_gated_packed_stride(cute::Stride,cute::Int<8>>, StrideIntT, cute::Int<0>>, cute::Shape shape_MKL) { + using namespace cute; + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto shape = sm90_make_gated_output_shape<0>(cute::transform(shape_MKL, [](auto s){ return static_cast(s); })); + auto stride = append(compact_order(take<0,2>(shape), Step,_2>{}), Int<0>{}); + return stride; +} + +} // namespace cutlass diff --git a/examples/113_hopper_gemm_activation_fusion/options.hpp b/examples/113_hopper_gemm_activation_fusion/options.hpp new file mode 100644 index 000000000..6679a1d18 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/options.hpp @@ -0,0 +1,380 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/gemm_coord.h" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#define OPTIONS_ERROR(...) 
\ +do { \ + std::stringstream ss; \ + ss << __VA_ARGS__; \ + throw std::runtime_error(ss.str()); \ +} \ +while (false) + +// Command line options parsing +template +struct GemmOptionsBase { + + using RasterOrderOptions = RasterOrderOptions_; + + bool help = false; + int iterations = 100; + int warmup = 100; + RasterOrderOptions raster = RasterOrderOptions::Heuristic; + int swizzle = 1; + int sm_count = 0; + int device = -1; + float tolerance = 2e-1f; + float nonzero_floor = 3e-1f; + bool verify = true; + bool verbose = false; + bool csv = false; + std::string program_path; + + cutlass::Distribution dist_a; + cutlass::Distribution dist_b; + cutlass::Distribution dist_c; + + std::string raster_string() const { + switch (raster) { + case RasterOrderOptions::Heuristic: return "Heuristic"; + case RasterOrderOptions::AlongM: return "AlongM"; + case RasterOrderOptions::AlongN: return "AlongN"; + } + return "Unknown"; + } + + static cutlass::Distribution + get_distribution( + cutlass::CommandLine const& cmd, + char const* arg_name) { + + struct { + const char *label; + cutlass::Distribution::Kind kind; + } distribution_kinds[] = { + {"uniform", cutlass::Distribution::Uniform}, + {"gaussian", cutlass::Distribution::Gaussian}, + {"sequential", cutlass::Distribution::Sequential}, + {0, cutlass::Distribution::Invalid} + }; + + cutlass::Distribution dist; + + struct { + char const *label; + double *member; + } members[] = { + {"min", &dist.uniform.min}, + {"max", &dist.uniform.max}, + {"mean", &dist.gaussian.mean}, + {"stddev", &dist.gaussian.stddev}, + {"pnzA", &dist.gaussian.pnzA}, + {"pnzB", &dist.gaussian.pnzB}, + {"pnzC", &dist.gaussian.pnzC}, + {"start", &dist.sequential.start}, + {"delta", &dist.sequential.delta}, + {0, 0} + }; + + using KeyValueVector = std::vector>; + + KeyValueVector values; + cmd.get_cmd_line_argument_pairs(arg_name, values); + + // The parser expects the first token to be a string identifying the distribution type. 
+ auto it = values.begin(); + if (it != values.end()) { + for (int i = 0; distribution_kinds[i].label; ++i) { + if (it->first == distribution_kinds[i].label) { + dist.kind = distribution_kinds[i].kind; + break; + } + } + ++it; + } + + // Default initialization + switch (dist.kind) { + case cutlass::Distribution::Uniform: + dist.set_uniform(-1/*min*/, 1/*max*/, -1/*int_scale*/); + break; + case cutlass::Distribution::Gaussian: + dist.set_gaussian(0/*mean*/, 1/*stddev*/, -1/*int_scale*/); + break; + case cutlass::Distribution::Sequential: + dist.set_sequential(0/*start*/, 1/*delta*/, -1/*int_scale*/); + break; + default: + dist.set_gaussian(0/*mean*/, 1/*stddev*/, -1/*int_scale*/); + return dist; + } + + // Subsequent key-value pairs update the named field of the distribution struct. + for (; it != values.end(); ++it) { + // Integer scaling factor - if < 0, no integer rounding is performed. + if ((it->first.compare("scale") == 0) && !it->second.empty()) { + std::stringstream ss; + ss << it->second; + ss >> dist.int_scale; + continue; // next token + } + + // Casts as integer without scaling + if (it->first.compare("integer") == 0) { + dist.int_scale = 0; + continue; // next token + } + + // initialize other members + for (int m = 0; members[m].label; ++m) { + if (it->first == members[m].label && !it->second.empty()) { + std::stringstream ss; + ss << it->second; + ss >> *(members[m].member); + } + } + } + + return dist; + } + + // Parses the command line + void parse(cutlass::CommandLine const& cmd) { + program_path = cmd.program_path; + + cmd.get_cmd_line_argument("help", help); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup", warmup); + cmd.get_cmd_line_argument("sms", sm_count); + cmd.get_cmd_line_argument("device", device); + cmd.get_cmd_line_argument("tolerance", tolerance); + cmd.get_cmd_line_argument("nonzero_floor", nonzero_floor); + cmd.get_cmd_line_argument("verify", verify); + cmd.get_cmd_line_argument("verbose", 
verbose); + cmd.get_cmd_line_argument("csv", csv); + + char raster_char = 'H'; + cmd.get_cmd_line_argument("raster", raster_char); + + if (std::toupper(raster_char) == 'N') { + raster = RasterOrderOptions::AlongN; + } + else if (std::toupper(raster_char) == 'M') { + raster = RasterOrderOptions::AlongM; + } + else if (std::toupper(raster_char) == 'H') { + raster = RasterOrderOptions::Heuristic; + } + else { + OPTIONS_ERROR("Invalid raster order: " << raster_char); + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + + dist_a = get_distribution(cmd, "adist"); + dist_b = get_distribution(cmd, "bdist"); + dist_c = get_distribution(cmd, "cdist"); + } +}; + +// Command line options parsing +struct GroupedGemmOptions : GemmOptionsBase::RasterOrderOptions> { + + using Base = GemmOptionsBase::RasterOrderOptions>; + + float alpha = FLT_MAX, beta = FLT_MAX; + int groups = 10; + std::vector problem_sizes; + + int align_m = 1; + int align_n = 1; + int align_k = 1; + + GroupedGemmOptions( + int align_m = 1, + int align_n = 1, + int align_k = 1) + : Base(), + align_m(align_m), + align_n(align_n), + align_k(align_k) {} + + // Parses the command line + void parse(cutlass::CommandLine const& cmd) { + Base::parse(cmd); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("groups", groups); + randomize_problems(cmd); + } + + template + static T + read_value(std::string const& s, T default_ = {}) { + std::istringstream ss(s); + T val; + ss >> val; + if (ss.fail()) { + val = default_; + } + return val; + } + + // Read from command line a comma-separated list of ranges of the form (:)[,...]. + // If only value specified, set =. + // If arg_name is not present on the command line, use default_ as the only value. + // If num_ranges >= 0, returns exactly num_ranges ranges, truncating the list or extending it + // by repeating the last value, if necessary. 
+ template + static std::vector> + get_int_ranges( + cutlass::CommandLine const& cmd, + char const* arg_name, + std::pair default_, + int num_ranges = -1) { + std::vector> input; + cmd.get_cmd_line_argument_pairs(arg_name, input); + + std::vector> result; + std::transform(input.begin(), input.end(), std::back_inserter(result), + [](auto const& range_str) { + T minval = read_value(range_str.first); + T maxval = read_value(range_str.second, minval); + return std::make_pair(minval, maxval); + }); + + if (result.empty()) { + result.push_back(default_); + } + + if (num_ranges >= 0) { + auto last = result.back(); + for (int i = static_cast(result.size()); i < num_ranges; ++i) { + result.push_back(last); + } + result.resize(num_ranges); + } + + return result; + } + + void randomize_problems(cutlass::CommandLine const& cmd) { + + auto m_ranges = get_int_ranges(cmd, "m", {10240, 10240}, groups); // Fixed "inter_size" in MoE + auto n_ranges = get_int_ranges(cmd, "n", { 1024, 2048}, groups); // Variable "token per expert" dimension in MoE, should always vary to test correctness + auto k_ranges = get_int_ranges(cmd, "k", { 8192, 8192}, groups); // Fixed "hidden_dim" in MoE + + auto random_size = [](auto vmin, auto vmax, auto align, int group, char const * name) { + auto avmin = (vmin + align - 1) / align; + auto avmax = vmax / align; + if (avmax - avmin < 0) { + OPTIONS_ERROR("Group " << group << ": problem size " << name << " range=[" << vmin << "," << vmax << "], must contain at least one multiple of " << align); + } + return align * ((rand() % (avmax - avmin + 1)) + avmin); + }; + + auto check_size = [](auto value, auto align, int group, char const * name) { + if (value <= 0) { + OPTIONS_ERROR("Group " << group << ": problem size " << name << "=" << value << ", must be positive"); + } + if (value % align != 0) { + OPTIONS_ERROR("Group " << group << ": problem size " << name << "=" << value << ", must be a multiple of " << align); + } + }; + + problem_sizes.reserve(groups); 
+ for (int i = 0; i < groups; ++i) { + int M = random_size(m_ranges[i].first, m_ranges[i].second, align_m, i, "M"); + int N = random_size(n_ranges[i].first, n_ranges[i].second, align_n, i, "N"); + int K = random_size(k_ranges[i].first, k_ranges[i].second, align_k, i, "K"); + check_size(M, align_m, i, "M"); + check_size(N, align_n, i, "N"); + check_size(K, align_k, i, "K"); + problem_sizes.emplace_back(M, N, K); + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << program_path << "\n" + "\n" + " Hopper Grouped Dual GEMM with fused activation.\n" + "\n" + "Options:\n" + "\n" + " --help Display this usage statement\n" + " --m=(:)[,...] Set the M range of the GEMM for each group (last range used for remaining groups)\n" + " --n=(:)[,...] Set the N range of the GEMM for each group (last range used for remaining groups)\n" + " --k=(:)[,...] Set the K range of the GEMM for each group (last range used for remaining groups)\n" + " --groups= Set the number of individual GEMM problems for Grouped GEMM\n" + " --alpha= Epilogue scalar alpha\n" + " --beta= Epilogue scalar beta\n" + " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n" + " --swizzle= CTA Rasterization swizzle\n" + " --warmup= Number of warmup iterations to perform\n" + " --iterations= Number of profiling iterations to perform\n" + " --verify= Verification (correctness check) on/off\n" + " --verbose Verbose mode (output detailed verification result)\n" + " --sms Number of SMs to run the GEMMs on\n" + " --device Device index\n" + "\n" + "Any problem size range can be specified as a pair : or a single integer (=) for a fixed size.\n" + "\n" + "Example:\n" + << program_path << " --m=5120 --n=1024:2048 --k=4096 --groups=10 --alpha=1 --beta=0\n"; + + return out; + } + + /// Compute number of flops + uint64_t total_flops() const { + // Two flops per multiply-add + return 2 * std::accumulate(problem_sizes.begin(), 
problem_sizes.end(), 0ULL, + [](uint64_t acc, auto p){ return acc + p.product(); }); + } +}; + +#undef OPTIONS_ERROR diff --git a/examples/113_hopper_gemm_activation_fusion/sm90_lin_comb_elt_act_scaled.hpp b/examples/113_hopper_gemm_activation_fusion/sm90_lin_comb_elt_act_scaled.hpp new file mode 100644 index 000000000..6cd9a0852 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/sm90_lin_comb_elt_act_scaled.hpp @@ -0,0 +1,298 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree node for gated activation function +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" // Sm90EVT +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" // Sm90ScalarBroadcastPtrArray +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" // Sm90AuxArrayStore +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" // Sm90Compute + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +template< + bool PtrArray, + class Element, + class Stride> +using Sm90ScalarBroadcastSelector = cute::conditional_t, + Sm90ScalarBroadcast +>; + +// D = activation(alpha * acc + beta * C) +template< + bool DoScale, + template class ActivationFn_, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + class ElementIntermediate = ElementOutput, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct AccCastLinCombEltActScale + : LinearCombination { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; 
+}; + +// D = activation(alpha * acc + beta * C) +template< + bool PtrArray, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + class ElementIntermediate = ElementOutput, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90AccCastLinCombEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc)) + // This is the same as Sm90LinearCombination except that it performs a roundtrip cast to ElementIntermediate + // after accumulator scaling but before adding source (bias) + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcastSelector>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcastSelector>, // alpha + Sm90AccFetch // acc + > + > + >; + +// D = scale * activation(alpha * acc + beta * C) +template< + bool PtrArray, + bool DoScale, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + class ElementIntermediate = ElementOutput, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90AccCastLinCombEltActScale = + cute::conditional_t, + Sm90ScalarBroadcastSelector>, + Sm90AccCastLinCombEltAct + >, + Sm90AccCastLinCombEltAct + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + bool DoScale, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + class ElementIntermediate, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::AccCastLinCombEltActScale, + CtaTileShapeMNK, + EpilogueTile +> : Sm90AccCastLinCombEltActScale { + + using Impl = Sm90AccCastLinCombEltActScale; + using Operation = fusion::AccCastLinCombEltActScale; + + struct Arguments { + + using StrideAlpha = 
Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + StrideAlpha dAlpha{}; + + using StrideBeta = Stride<_0,_0,int64_t>; + ElementScalar beta = ElementScalar(0); + ElementScalar const* beta_ptr{}; + StrideBeta dBeta{}; + + using StrideScale = Stride<_0,_0,int64_t>; + ElementScalar scale = ElementScalar(1); + ElementScalar const* scale_ptr{}; + StrideScale dScale{}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + + using SubImpl = Sm90AccCastLinCombEltAct; + typename SubImpl::Arguments actlincomb_args + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, + {} // ternary args : multiply_add + }, + activation // unary args: activation + }; + + return [&]() { + if constexpr (DoScale) { + return typename Impl::Arguments + { // binary op : scale * (actlincomb) + {{scale}, {scale_ptr}, {dScale}}, // leaf args : scale + actlincomb_args, // leaf_args : actlincomb + {} // leaf args : multiplies + }; + } + else { + return actlincomb_args; + } + }(); + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + bool DoScale, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + class ElementIntermediate, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::AccCastLinCombEltActScale, + CtaTileShapeMNK, + EpilogueTile +> : 
Sm90AccCastLinCombEltActScale { + + using Impl = Sm90AccCastLinCombEltActScale; + using Operation = fusion::AccCastLinCombEltActScale; + + struct Arguments { + + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBeta = Stride<_0,_0,int64_t>; + ElementScalar beta = ElementScalar(0); + ElementScalar const* beta_ptr{}; + ElementScalar const* const* beta_ptr_array{}; + StrideBeta dBeta{}; + + using StrideScale = Stride<_0,_0,int64_t>; + ElementScalar scale = ElementScalar(1); + ElementScalar const* scale_ptr{}; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + + using SubImpl = Sm90AccCastLinCombEltAct; + typename SubImpl::Arguments actlincomb_args + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, + {} // ternary args : multiply_add + }, + activation // unary args: activation + }; + + return [&]() { + if constexpr (DoScale) { + return typename Impl::Arguments + { // binary op : scale * (actlincomb) + {{scale}, {scale_ptr}, {scale_ptr_array}, {dScale}}, // leaf args : scale + actlincomb_args, // leaf args : actlincomb + {} // leaf args : multiplies + }; + } + else { + return actlincomb_args; + } + }(); + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +} // namespace cutlass::epilogue::fusion diff --git a/examples/113_hopper_gemm_activation_fusion/sm90_visitor_gated_act.hpp 
b/examples/113_hopper_gemm_activation_fusion/sm90_visitor_gated_act.hpp new file mode 100644 index 000000000..31dd28773 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/sm90_visitor_gated_act.hpp @@ -0,0 +1,614 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree node for gated activation function +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" // Sm90EVT +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" // Sm90ScalarBroadcast(PtrArray) +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" // Sm90Aux(Array)Store +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" // Sm90Compute + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +template< + bool PtrArray, + class Element, + class Stride> +using Sm90ScalarBroadcastSelector = cute::conditional_t, + Sm90ScalarBroadcast +>; + +template< + bool PtrArray, + bool Quantize, + template class ActivationFn, + int Stages, + int NumEpilogueWarpGroups, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest> +struct Sm90GatedActivation +{ + // Transparently handle PtrArray/GroupGemm case by using a dummy shape on host + template + CUTLASS_HOST_DEVICE + static constexpr auto + get_problem_shape(ProblemShape const& problem_shape) { + if constexpr (PtrArray) { + return typename ProblemShape::UnderlyingProblemShape{}; + } + else { + return problem_shape; + } + } + + // Convert input problem shape [(M,2),N,K,L] to output problem shape [M,N,K,L] + template + CUTLASS_HOST_DEVICE + static constexpr auto + to_output_shape(Shape const& shape) { + using namespace cute; + static_assert(CUTE_STATIC_V(rank<0>(shape)) == 3, "Input shape/coord must have a rank-3 M-mode"); + 
auto M = remove<1>(get<0>(shape)); + return replace<0>(shape, M); + } + + using EpilogueTileOut = decltype(to_output_shape(EpilogueTile{})); + + // Define sub-EVTs below that will be invoked manually + // Cannot compose them using normal EVT structure due to gated activation logic: + // 1. Compute EVT (activation) is only visited on "bottom" half of the values + // 2. Store EVT is visited after multiplying gating and activation values, + // which needs access to the whole epilogue tile, i.e. in reduce() + + using ComputeOp = Sm90Compute; + using ComputeEVT = Sm90EVT; // leaf input slot + + using StoreOp = cute::conditional_t, + Sm90AuxStore + >; + using ScaleOp = Sm90EVT, // scale op + Sm90ScalarBroadcastSelector>, // scale factor broadcast + Sm90AccFetch>; // leaf input slot + using StoreEVT = Sm90EVT>; + + // Delegate most operations to generic Sm90Visitor even though we don't inherit from it + using Impl = Sm90EVT; + + using SharedStorage = typename Impl::SharedStorage; + using Arguments = typename Impl::Arguments; + using Params = typename Impl::Params; + template + using TensorMaps = typename Impl::template TensorMaps; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return Impl::to_underlying_arguments(to_output_shape(get_problem_shape(problem_shape)), args, workspace); + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + // unlike other host APIs, can_implement gets passed underlying problem shape for Grouped Gemm cases + return Impl::can_implement(to_output_shape(problem_shape), args); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return Impl::get_workspace_size(to_output_shape(get_problem_shape(problem_shape)), args); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, 
void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Impl::initialize_workspace(to_output_shape(get_problem_shape(problem_shape)), args, workspace, stream); + } + + CUTLASS_HOST_DEVICE + Sm90GatedActivation() : impl() { } + + CUTLASS_HOST_DEVICE + Sm90GatedActivation(Params const& params, SharedStorage const& shared_storage) + : impl(params, shared_storage) { } + + Impl impl; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return impl.is_producer_load_needed(); + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return impl.is_C_load_needed(); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return impl.get_producer_load_callbacks(args); + } + + template < + class CallbacksImpl, + class CrdTensor + > + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + CallbacksImpl impl, + CrdTensor tRS_cD) + : CallbacksImpl(impl), + tRS_cD(tRS_cD) {} + + using CallbacksImpl::callbacks_tuple; + CrdTensor tRS_cD; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using namespace cute; + + static_assert(FragmentSize % 4 == 0, "Fragment size is too small"); + using FrgOutput = Array; + + // This splitting relies on details of WGMMA accumulator register layout + Tensor input = flat_divide(make_tensor(frg_input.data(), Int{}), Layout>{}); + Tensor input_val = make_tensor_like(input(_,0,_)); + Tensor input_act = make_tensor_like(input(_,1,_)); + copy(input(_,0,_), input_val); + copy(input(_,1,_), input_act); + + FrgOutput const& frg_input_val = recast(input_val)(0); + FrgOutput const& frg_input_act = recast(input_act)(0); + + // store(gemm0 * act(gemm1)) + FrgOutput frg_output_act = get<0>(callbacks_tuple).visit(frg_input_act, epi_v, epi_m, epi_n); + FrgOutput frg_output = frg_input_val * frg_output_act; + get<1>(callbacks_tuple).visit(frg_output, epi_v, 
epi_m, epi_n); + + return frg_input; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + using namespace cute; + + // Transform TV layout of the tiled copy by removing every other group of 8 rows. + // Note: assumes by-mode tilers that are bijective here - not necessarily the case in general! + + auto tiler_mn = typename decltype(args.tiled_copy)::Tiler_MN{}; + auto layout_tv = typename decltype(args.tiled_copy)::TiledLayout_TV{}; + auto [tiler_m, tiler_n] = tiler_mn; + int constexpr TileM = CUTE_STATIC_V(size(tiler_m)); + int constexpr TileN = CUTE_STATIC_V(size(tiler_n)); + auto row_selector = Layout>, Stride<_1,_16>>{}; // select every other group of 8 rows + auto col_selector = Layout>{}; // select all columns + + auto tiler_mn_out = + make_tile( + right_inverse( + make_layout_like( + composition( + right_inverse(tiler_m), + row_selector + ) + ) + ), + tiler_n + ); + + auto layout_tv_out = + right_inverse( // t,v/2 -> copy m/2,n + composition( // copy m/2,n -> t,v/2 + make_layout_like( // real m/2,n -> t,v/2 + composition( // real m/2,n -> t,v/2 + composition( // real m,n -> t,v + right_inverse(layout_tv).with_shape(shape(tiler_mn)), // copy m,n -> t,v + make_tile(right_inverse(tiler_m), right_inverse(tiler_n)) // real m,n -> copy m,n + ), + make_tile(row_selector, col_selector) // real m,n -> real m/2,n + ) + ), + tiler_mn_out + ) + ).with_shape(make_shape(size<0>(layout_tv), size<1>(layout_tv)/_2{})); // t,v/2 -> copy m/2,n + + auto tiled_copy = TiledCopy, decltype(layout_tv_out), decltype(tiler_mn_out)>{}; + + auto args_impl = ConsumerStoreArgs{ + to_output_shape(args.problem_shape_mnkl), + to_output_shape(args.tile_shape_mnk), + to_output_shape(args.tile_coord_mnkl), + args.tiled_mma, + EpilogueTileOut{}, + tiled_copy, + args.cD, + args.residue_cD, + args.tCcD, + args.residue_tCcD, + 
args.tCrC, + args.thread_idx + }; + + auto cst_impl = impl.get_consumer_store_callbacks(args_impl); + + return ConsumerStoreCallbacks( + cst_impl, + args.tCcD); + } + + template + struct TensorMapCallbacks : CallbacksImpl { + + CUTLASS_DEVICE + TensorMapCallbacks(CallbacksImpl&& impl) : CallbacksImpl(cute::move(impl)) {} + + template + CUTLASS_DEVICE + void + perform_update( + TensorMaps tensormaps, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch, + int32_t warp_group_idx) + { + CallbacksImpl::perform_update(tensormaps, to_output_shape(problem_shape_mnkl), next_batch, warp_group_idx); + } + }; + + template + CUTLASS_DEVICE constexpr auto + get_tensormap_callbacks() { + auto tmap_callbacks = impl.template get_tensormap_callbacks(); + return TensorMapCallbacks(cute::move(tmap_callbacks)); + } +}; + +template< + bool Quantize, // whether to quantize output with a per-tensor scale factor + template class ActivationFn, + class GmemLayoutTagOutput, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + class ElementIntermediate = ElementOutput, + int Alignment = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +struct LinCombGatedActFunc + : LinearCombination { + using ElementAux = ElementOutput; + using GmemLayoutTagAux = GmemLayoutTagOutput; + static constexpr int AlignmentAux = Alignment; + static constexpr bool IsAuxOutSupported = true; +}; + +template< + bool PtrArray, + bool Quantize, + template class ActivationFn, + int Stages, + int NumEpilogueWarpGroups, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + class ElementIntermediate = ElementOutput, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombGatedActFunc = + Sm90EVT, // store(x(0) * 
f(x(1) * scale)) + // This is the same as Sm90LinearCombinationPtrArray except that it performs a roundtrip cast to ElementIntermediate + // after accumulator scaling but before adding source (bias), which emulates precision of unfused path + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcastSelector>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcastSelector>, // alpha + Sm90AccFetch // acc + > + > + >; + + template < + // DispatchPolicy args + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + // Fusion op args + // Gated act + quantization args + bool Quantize, + template class ActivationFn, + // Store args + class GmemLayoutTagOutput, + // Element types + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + class ElementIntermediate, + int Alignment, + FloatRoundStyle RoundStyle, + // Tile shape args + class CtaTileShapeMNK, + class EpilogueTile, + // Aux store args + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombGatedActFunc, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombGatedActFunc, + SmemLayoutAtom, CopyOpR2S, ElementOutput, ElementCompute, ElementSource, ElementScalar, ElementIntermediate, RoundStyle> { + + using Impl = Sm90LinCombGatedActFunc, + SmemLayoutAtom, CopyOpR2S, ElementOutput, ElementCompute, ElementSource, ElementScalar, ElementIntermediate, RoundStyle>; + using Operation = fusion::LinCombGatedActFunc; + + struct Arguments { + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + StrideAlpha dAlpha{}; + + using StrideBeta = Stride<_0,_0,int64_t>; + ElementScalar beta = ElementScalar(0); + ElementScalar const* beta_ptr{}; + StrideBeta dBeta{}; + + using StrideScale = Stride<_0,_0,int64_t>; + ElementScalar scale = ElementScalar(1); + ElementScalar const* scale_ptr{}; 
+ StrideScale dScale{}; + + using StrideOutput = cutlass::gemm::TagToStrideC_t; + ElementOutput* ptr_D{}; + StrideOutput dD{}; + + int sm_count{}; + + operator typename Impl::Arguments() const { + + using StoreArgs = decltype(typename Impl::Arguments{}.op_1.op_1); + + StoreArgs store_args = [&]{ + if constexpr (Quantize) { + return StoreArgs + { // custom node : conversion + store + { // binary op : conversion + scale + {{scale}, {scale_ptr}, {dScale}}, // leaf args : scalar broadcast (scale) + {}, // leaf args : acc fetch (input) + {} // binary args : multiplies + }, + {ptr_D, dD}, // unary op : aux store + }; + } + else { + return StoreArgs + { // unary op : aux store + {}, // leaf args : acc fetch (input) + {ptr_D, dD} // unary args : aux store + }; + } + }(); + + return + { // unary op: store(scale(gated_act(beta * C + (alpha * acc)))) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, + {} // ternary args : multiply_add + }, + { // custom node : gated_act+scale+store custom node + { // unary op : act_func(input) + {}, // leaf args : input + {} // unary args : act_func + }, + store_args + } + }; + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +template < + // DispatchPolicy args + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + // Fusion op args + // Gated act + quantization args + bool Quantize, + template class ActivationFn, + // Store args + class GmemLayoutTagOutput, + // Element types + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + class ElementIntermediate, + int Alignment, + FloatRoundStyle RoundStyle, + // Tile shape args + class CtaTileShapeMNK, + class EpilogueTile, + // Aux store args + class 
SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinCombGatedActFunc, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombGatedActFunc, + SmemLayoutAtom, CopyOpR2S, ElementOutput, ElementCompute, ElementSource, ElementScalar, ElementIntermediate, RoundStyle> { + + using Impl = Sm90LinCombGatedActFunc, + SmemLayoutAtom, CopyOpR2S, ElementOutput, ElementCompute, ElementSource, ElementScalar, ElementIntermediate, RoundStyle>; + using Operation = fusion::LinCombGatedActFunc; + + struct Arguments { + using StrideAlpha = Stride<_0,_0,int64_t>; + ElementScalar alpha = ElementScalar(1); + ElementScalar const* alpha_ptr{}; + ElementScalar const* const* alpha_ptr_array{}; + StrideAlpha dAlpha{}; + + using StrideBeta = Stride<_0,_0,int64_t>; + ElementScalar beta = ElementScalar(0); + ElementScalar const* beta_ptr{}; + ElementScalar const* const* beta_ptr_array{}; + StrideBeta dBeta{}; + + using StrideScale = Stride<_0,_0,int64_t>; + ElementScalar scale = ElementScalar(1); + ElementScalar const* scale_ptr{}; + ElementScalar const* const* scale_ptr_array{}; + StrideScale dScale{}; + + using StrideOutput = cutlass::gemm::TagToStrideC_t; + ElementOutput** ptr_D{}; + StrideOutput dD{}; + + int sm_count{}; + + operator typename Impl::Arguments() const { + + using StoreArgs = decltype(typename Impl::Arguments{}.op_1.op_1); + + StoreArgs store_args = [&]{ + if constexpr (Quantize) { + return StoreArgs + { // custom node : conversion + store + { // binary op : conversion + scale + {{scale}, {scale_ptr}, {scale_ptr_array}, {dScale}}, // leaf args : scalar broadcast (scale) + {}, // leaf args : acc fetch (input) + {} // binary args : multiplies + }, + {ptr_D, dD, sm_count}, // unary op : aux store + }; + } + else { + return StoreArgs + { // unary op : aux store + {}, // leaf args : acc fetch (input) + {ptr_D, dD, sm_count} // unary args : aux store + }; + } + }(); + + return + { // unary op: 
store(scale(gated_act(beta * C + (alpha * acc)))) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, + {} // ternary args : multiply_add + }, + { // custom node : gated_act+scale+store custom node + { // unary op : act_func(input) + {}, // leaf args : input + {} // unary args : act_func + }, + store_args + } + }; + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +} // namespace cutlass::epilogue::fusion diff --git a/examples/113_hopper_gemm_activation_fusion/tile_scheduler_group.hpp b/examples/113_hopper_gemm_activation_fusion/tile_scheduler_group.hpp new file mode 100644 index 000000000..1934965b8 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/tile_scheduler_group.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +// A version of Persistent Group scheduler that preserves multimodal tiling +template +class PersistentTileSchedulerSm90GroupTileShapeDependent +: public cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Group { + +public: + + static_assert(cute::is_static_v, "TileShape must be static"); + + using Base = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Group; + using Base::Base; + using WorkTileInfo = typename Base::WorkTileInfo; + + // Customize this function to pass static (hierarchical) tile shape instead of dynamic flattened tile shape + CUTLASS_DEVICE + WorkTileInfo + get_current_work_for_linear_idx(uint64_t linear_idx) { + if (this->scheduler_params.pre_processed_problem_shapes && linear_idx >= this->scheduler_params.blocks_across_problem_) { + return WorkTileInfo::invalid_work_tile(); + } + + return this->template get_work_idx_m_and_n( + linear_idx, + this->current_group_info_, + this->scheduler_params.problem_shapes_, + this->cached_problem_shapes_, + TileShape{}, + 
this->scheduler_params.cluster_shape_, + this->scheduler_params.divmod_cluster_shape_major_, + this->scheduler_params.divmod_cluster_shape_minor_, + this->scheduler_params.divmod_cta_shape_m_, + this->scheduler_params.divmod_cta_shape_n_, + this->scheduler_params.max_swizzle_size_, + this->scheduler_params.raster_order_); + } + + // Must re-implement every function that calls get_current_work_for_linear_idx() to get the call to resolve to correct version + + template + CUTLASS_DEVICE + auto + advance_to_next_work( + TileSchedulerPipeline& scheduler_pipeline, + TileSchedulerPipelineState scheduler_pipe_producer_state, + uint32_t advance_count = 1, + CallbackBeforeCommit callback_before_commit = [] (WorkTileInfo info) { return info; }) { + + this->current_work_linear_idx_ += this->total_grid_size_ * uint64_t(advance_count); + auto work_tile = get_current_work_for_linear_idx(this->current_work_linear_idx_); + using WorkTileWithCallbackInfo = decltype(callback_before_commit(work_tile)); + WorkTileWithCallbackInfo work_tile_with_callback_info = work_tile; + scheduler_pipeline.producer_acquire(scheduler_pipe_producer_state); + if (work_tile_with_callback_info.is_valid()) { + work_tile_with_callback_info = callback_before_commit(work_tile); + } + if (cute::elect_one_sync()) { + reinterpret_cast(this->response_ptr_)[scheduler_pipe_producer_state.index()] = work_tile_with_callback_info; + cutlass::arch::fence_view_async_shared(); + scheduler_pipeline.producer_commit(scheduler_pipe_producer_state); + } + return cute::make_tuple(work_tile_with_callback_info, true); + } + + // Returns the initial work tile info that will be computed over + template + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info(ClusterShape) { + return get_current_work_for_linear_idx(this->current_work_linear_idx_); + } +}; + +// Derives from GroupScheduler so the cooperative kernel's scheduler-compatibility +// static_assert (is_base_of_v) accepts this tag. 
+struct GroupSchedulerTileShapeDependent : cutlass::gemm::GroupScheduler {}; + +namespace cutlass::gemm::kernel::detail { + +template < + class TileShape, + class ClusterShape, + uint32_t SchedulerPipelineStageCount, + class GroupProblemShape +> +struct TileSchedulerSelector< + GroupSchedulerTileShapeDependent, + arch::Sm90, + TileShape, + ClusterShape + , SchedulerPipelineStageCount + , GroupProblemShape + > { + using Scheduler = PersistentTileSchedulerSm90GroupTileShapeDependent; +}; + +} diff --git a/examples/113_hopper_gemm_activation_fusion/utils.hpp b/examples/113_hopper_gemm_activation_fusion/utils.hpp new file mode 100644 index 000000000..6a8c00b49 --- /dev/null +++ b/examples/113_hopper_gemm_activation_fusion/utils.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +template +std::string kernel_schedule_string() { + if constexpr (cute::is_base_of_v) { + return "Non-persistent"; + } + else if constexpr (cute::is_base_of_v || + cute::is_base_of_v) { + return "Pingpong"; + } + else if constexpr (cute::is_base_of_v || + cute::is_base_of_v) { + return "Cooperative"; + } + else { + return "Unknown"; + } +} + +template