From 1b4e24470a369fc0dfc12987c2a43036207b4f04 Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Wed, 26 Oct 2022 20:04:42 +0200 Subject: [PATCH] Example 43 - DualGemm (#670) * Ex50 wip * IS_PROFILING mode * MultiStage2 - but is slower * Add SwiGLU * Support SplitKSerial reduction Support not storing D0/D1 Cleanup code * Option to disable bias * Renumber example * Fix build * Remove references to pb_size_0 / pb_size_1 * Add support for bf16 inputs with float accum * small changes Co-authored-by: danthe3rd Co-authored-by: Haicheng Wu --- examples/43_dual_gemm/CMakeLists.txt | 36 + examples/43_dual_gemm/device/dual_gemm.h | 457 ++++++++++ examples/43_dual_gemm/dual_gemm.cu | 262 ++++++ examples/43_dual_gemm/dual_gemm_run.h | 829 ++++++++++++++++++ examples/43_dual_gemm/kernel/dual_gemm.h | 489 +++++++++++ examples/43_dual_gemm/test_run.h | 95 ++ .../43_dual_gemm/thread/left_silu_and_mul.h | 150 ++++ .../43_dual_gemm/threadblock/dual_epilogue.h | 430 +++++++++ .../43_dual_gemm/threadblock/dual_mma_base.h | 218 +++++ .../threadblock/dual_mma_multistage.h | 760 ++++++++++++++++ examples/CMakeLists.txt | 1 + include/cutlass/gemm/threadblock/mma_base.h | 1 + 12 files changed, 3728 insertions(+) create mode 100644 examples/43_dual_gemm/CMakeLists.txt create mode 100644 examples/43_dual_gemm/device/dual_gemm.h create mode 100644 examples/43_dual_gemm/dual_gemm.cu create mode 100644 examples/43_dual_gemm/dual_gemm_run.h create mode 100644 examples/43_dual_gemm/kernel/dual_gemm.h create mode 100644 examples/43_dual_gemm/test_run.h create mode 100644 examples/43_dual_gemm/thread/left_silu_and_mul.h create mode 100644 examples/43_dual_gemm/threadblock/dual_epilogue.h create mode 100644 examples/43_dual_gemm/threadblock/dual_mma_base.h create mode 100644 examples/43_dual_gemm/threadblock/dual_mma_multistage.h diff --git a/examples/43_dual_gemm/CMakeLists.txt b/examples/43_dual_gemm/CMakeLists.txt new file mode 100644 index 000000000..8433b1af9 --- /dev/null +++ b/examples/43_dual_gemm/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 43_dual_gemm + dual_gemm.cu + ) + diff --git a/examples/43_dual_gemm/device/dual_gemm.h b/examples/43_dual_gemm/device/dual_gemm.h new file mode 100644 index 000000000..81e301cdc --- /dev/null +++ b/examples/43_dual_gemm/device/dual_gemm.h @@ -0,0 +1,457 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Performs a dual gemm in one fused kernel: +``` +D0 = epilogue0(X @ B0, C0) +D1 = epilogue1(X @ B1, C1) +D2 = element_wise(D0, D1) +``` +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" + +#include "../kernel/dual_gemm.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp0_, + typename EpilogueOutputOp1_, + typename EpilogueOutputOp2_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + bool StoreD0 = true, + bool StoreD1 = true, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class DualGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using EpilogueOutputOp2 = EpilogueOutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp1::kCount; + static bool const kSplitKSerial = SplitKSerial; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + using LayoutScaleBias = layout::RowMajor; + /// Define the kernel + /// Define the threadblock-scoped matrix multiply-accumulate + static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented"); + static_assert(kStages >= 3, "Only multistage is implemented"); + using Mma = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator>::ThreadblockMma; + using DualMma = threadblock::DualMmaMultistage< + typename Mma::Shape, + typename Mma::IteratorA, + typename Mma::SmemIteratorA, + Mma::kCacheOpA, + typename Mma::IteratorB, + typename Mma::SmemIteratorB, + Mma::kCacheOpB, + typename Mma::ElementC, + typename Mma::LayoutC, + typename Mma::Policy, + Mma::kStages, + SharedMemoryClearOption::kNone + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue0 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0, + EpilogueOutputOp0::kCount>::Epilogue; + using Epilogue1 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using DualGemmKernel = kernel::DualGemm< + DualMma, + Epilogue0, Epilogue1, EpilogueOutputOp2, + ThreadblockSwizzle, kSplitKSerial, + kStoreD0, kStoreD1>; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A0; + TensorRef ref_B0; + TensorRef ref_C0; + TensorRef ref_D0; + TensorRef ref_B1; + TensorRef ref_C1; + TensorRef ref_D1; + TensorRef ref_D2; + typename EpilogueOutputOp0::Params epilogue0; + typename EpilogueOutputOp1::Params epilogue1; + typename EpilogueOutputOp2::Params epilogue2; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A0_, + TensorRef ref_B0_, + TensorRef ref_C0_, + TensorRef ref_D0_, + TensorRef ref_B1_, + TensorRef ref_C1_, + TensorRef ref_D1_, + TensorRef ref_D2_, + typename EpilogueOutputOp0::Params epilogue0_ = + typename EpilogueOutputOp0::Params(), + typename EpilogueOutputOp1::Params epilogue1_ = + typename EpilogueOutputOp1::Params(), + typename EpilogueOutputOp2::Params epilogue2_ = + typename EpilogueOutputOp2::Params(), + int split_k_slices_ = 1 + ): + problem_size(problem_size_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_D0(ref_D0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + ref_D2(ref_D2_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + epilogue2(epilogue2_), + split_k_slices(split_k_slices_) { + + } + }; + +private: + + /// Kernel parameters object + typename DualGemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + DualGemm() = default; + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + if (kStoreD0 != (args.ref_D0.data() != nullptr)) { + return Status::kErrorInternal; + } + if (kStoreD1 != (args.ref_D1.data() != nullptr)) { + return Status::kErrorInternal; + } + + Status status = DualGemmKernel::can_implement( + args.problem_size, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2 + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename DualGemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2, + args.epilogue0, + args.epilogue1, + args.epilogue2, + reinterpret_cast(workspace), + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); + params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); + params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); + params_.ref_D0.reset(args.ref_D0.data()); + params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); + params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); + params_.ref_D1.reset(args.ref_D1.data()); + params_.ref_D2.reset(args.ref_D2.data()); + params_.output_op_0 = args.epilogue0; + params_.output_op_1 = args.epilogue1; + params_.output_op_2 = args.epilogue2; + params_.semaphore = reinterpret_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(DualGemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/43_dual_gemm/dual_gemm.cu b/examples/43_dual_gemm/dual_gemm.cu new file mode 100644 index 000000000..cbe76756e --- /dev/null +++ b/examples/43_dual_gemm/dual_gemm.cu @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Dual-GEMM Example. + + Fused kernel that outputs `D0` and `D1`. + We assume that B0/B1 have the same shape/layout + +``` +D0 = epilogue0(X @ B0, C0) +D1 = epilogue1(X @ B1, C1) +D2 = element_wise(D0, D1) +``` + D0 and D1 will be optionally stored in gmem (`kStoreD0` / `kStoreD1`) +*/ + +// #define IS_PROFILING + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "device/dual_gemm.h" +#include "thread/left_silu_and_mul.h" +#include "dual_gemm_run.h" +#include "test_run.h" + + +//////////////////////////////////////////////////////////////////////////////// + +cutlass::gemm::GemmCoord problem_size(4096, 4096, 8192); + +constexpr int kStages = 3; +constexpr bool kSplitKSerial = false; +constexpr bool kUseBias = true; + + +#if 0 +using ElementOperandA = cutlass::bfloat16_t; +using ElementOperandB = cutlass::bfloat16_t; +using ElementOutput = cutlass::bfloat16_t; +using ElementAccumulator = float; +using ElementCompute = float; +#else +using ElementOperandA = cutlass::half_t; +using ElementOperandB = cutlass::half_t; +using ElementOutput = cutlass::half_t; +using ElementAccumulator = cutlass::half_t; +using ElementCompute = cutlass::half_t; +#endif + +constexpr auto kScaleType = kUseBias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling : ( + // No bias + kSplitKSerial ? cutlass::epilogue::thread::ScaleType::Default : cutlass::epilogue::thread::ScaleType::Nothing +); +using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute, + kScaleType +>; +using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute, + kScaleType +>; +using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementOutput, + ElementCompute +>; + +const ElementCompute alpha0 = ElementCompute(1); +const ElementCompute beta0 = ElementCompute(kUseBias ? 1 : 0); +const ElementCompute alpha1 = ElementCompute(1); +const ElementCompute beta1 = ElementCompute(kUseBias ? 1 : 0); + +bool run_nonfused_gemm_f16_sm80() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Gemm0 = cutlass::gemm::device::Gemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + 8, + 8, + kSplitKSerial + >; + using Gemm1 = cutlass::gemm::device::Gemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp1, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + 8, + 8, + kSplitKSerial + >; + + NonFusedDualGemmRun nonFusedGemm; + + std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n"; + bool pass = nonFusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1); + if(pass) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return pass; +} + +template +struct LeftSiLUAndMul { + struct Params{}; + CUTLASS_HOST_DEVICE LeftSiLUAndMul(Params p) {} + + CUTLASS_HOST_DEVICE void set_k_partition(int, int) {} + + CUTLASS_HOST_DEVICE T operator() ( + T const &lhs, + T const &rhs) const { + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(lhs); + return mul(silu_lhs, rhs); + } + + template + CUTLASS_HOST_DEVICE cutlass::Array operator() ( + cutlass::Array const &lhs, + cutlass::Array const &rhs) const { + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(lhs); + return mul(silu_lhs, rhs); + } +}; + +bool run_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n"; + bool passed = fusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1); + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; + +} + +int main() { + + std::vectorfuncs = { + &run_nonfused_gemm_f16_sm80, + &run_fused_gemm_f16_sm80_shmem + }; + + std::string test_name = "dual-gemm f16 bias=" + std::to_string(kUseBias) + " split_k_serial=" + std::to_string(kSplitKSerial); + return testRun(80, funcs, test_name); +} + + + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/43_dual_gemm/dual_gemm_run.h b/examples/43_dual_gemm/dual_gemm_run.h new file mode 100644 index 000000000..df7c04b4e --- /dev/null +++ b/examples/43_dual_gemm/dual_gemm_run.h @@ -0,0 +1,829 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +template < + typename OutputOp, + typename Element, + typename Layout> +struct TensorEpilogueForEachFunc { + /// View type + using TensorView = cutlass::TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view_x0; + TensorView view_x1; + TensorView view_y; + OutputOp output_op; + + + // + // Methods + // + + Params( + TensorView view_x0_ = TensorView(), + TensorView view_x1_ = TensorView(), + TensorView view_y_ = TensorView(), + OutputOp output_op_ = OutputOp(typename OutputOp::Params{}) + ): + view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) { + } + }; + + Params params; + + CUTLASS_DEVICE + TensorEpilogueForEachFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + Element const & x0 = params.view_x0.at(coord); + Element const & x1 = params.view_x1.at(coord); + Element& y = params.view_y.at(coord); + y = params.output_op(x0, x1); + } +}; + +template < + typename OutputOp, + typename Element, + typename Layout> +void TensorEpilogueForEach( + cutlass::TensorView x0, + cutlass::TensorView x1, + cutlass::TensorView y) { + + using Func = TensorEpilogueForEachFunc; + using Params = typename Func::Params; + + cutlass::reference::device::TensorForEach( + y.extent(), + Params(x0, x1, y) + ); +} + +//////////////////////////////////////////////////////////////////////////////// + +template +struct NonFusedDualGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + NonFusedDualGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + reference_D0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1; + typename Gemm0::Arguments arguments_0{ + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + tensor_D0.device_ref(), + {alpha0, beta0}, + split_k_slices + }; + + split_k_slices = Gemm1::kSplitKSerial ? 2 : 1; + typename Gemm1::Arguments arguments_1{ + problem_size, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1}, + split_k_slices + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + // Allocate workspace memory + cutlass::device_memory::allocation workspace0(gemm_op_0.get_workspace_size(arguments_0)); + cutlass::device_memory::allocation workspace1(gemm_op_1.get_workspace_size(arguments_1)); + + cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get()); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1, workspace1.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } +#ifdef IS_PROFILING + return true; +#endif + // + // Run the GEMM + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size, + alpha1, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed0 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed0); + + bool passed1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed1); + if (!passed0 || !passed1) { + + std::stringstream fname; + + fname << "error_DualGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed0 && passed1; + } +}; + +template +struct DualFusedGemmRun +{ + + using DualGemm = DualGemm_; + using ElementAccumulator = typename DualGemm::ElementAccumulator; + using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute; + using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + DualFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(1), + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename DualGemm::ElementA, + typename DualGemm::LayoutA> tensor_A0(problem_size.mk()); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB> tensor_B0(problem_size.kn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C0(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D0(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D0(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB> tensor_B1(problem_size.kn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C1(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D1(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D2(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D1(problem_size.mn()); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D2(problem_size.mn()); + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + tensor_D2.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D2.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D0.sync_device(); + tensor_D1.sync_device(); + tensor_D2.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + reference_D2.sync_device(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1; + typename cutlass::TensorRef nullptr_ref{}; + decltype(nullptr_ref) ref_B0, ref_B1; + if (beta0 != ElementCompute(0)) { + ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + if (beta1 != ElementCompute(0)) { + ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + typename DualGemm::Arguments arguments{ + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + ref_B0, + DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref, + tensor_B1.device_ref(), + ref_B1, + DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref, + tensor_D2.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + {}, + split_k_slices + }; + + DualGemm b2b_gemm_op; + + cutlass::device_memory::allocation workspace(b2b_gemm_op.get_workspace_size(arguments)); + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + CUTLASS_CHECK(status); + + status = b2b_gemm_op.initialize(arguments, workspace.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + +#ifdef IS_PROFILING + return true; +#endif + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + tensor_D2.sync_host(); + + // + // Verify + // + + cutlass::reference::device::Gemm< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator, ElementAccumulator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB, + typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute, + ElementAccumulator, typename DualGemm::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size, + alpha1, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view()); + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + reference_D2.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0); + + bool passed_out0 = true; + if (DualGemm::kStoreD0) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + passed_out0 = cutlass::reference::host::TensorEquals( + reference_D0.host_view(), + tensor_D0.host_view()); + } + CHECK_TRUE(passed_out0); + + bool passed_out1 = true; + if (DualGemm::kStoreD1) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + passed_out1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + } + CHECK_TRUE(passed_out1); + + bool passed_out2 = cutlass::reference::host::TensorEquals( + reference_D2.host_view(), + tensor_D2.host_view()); + CHECK_TRUE(passed_out2); + + bool passed = passed_out0 && passed_out1 && passed_out2; + if (!passed) + { + + std::stringstream fname; + + fname << "error_DualGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference0 =\n" << reference_D0.host_view() + << "\nComputed0 =\n" << tensor_D0.host_view() + << "\n\nReference1 =\n" << reference_D1.host_view() + << "\nComputed1 =\n" << tensor_D1.host_view() + << "\n\nReference2 =\n" << reference_D2.host_view() + << "\nComputed2 =\n" << tensor_D2.host_view(); + } + //std::cout << "A0 " << tensor_A0.host_view() << std::endl; + // std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + // std::cout << "reference_D1 " << reference_D1.host_view() << std::endl; + // std::cout << "reference_D2 " << reference_D2.host_view() << std::endl; + //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/43_dual_gemm/kernel/dual_gemm.h b/examples/43_dual_gemm/kernel/dual_gemm.h new file mode 100644 index 000000000..aa7051b37 --- /dev/null +++ b/examples/43_dual_gemm/kernel/dual_gemm.h @@ -0,0 +1,489 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "../threadblock/dual_mma_multistage.h" +#include "../threadblock/dual_epilogue.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue0_, ///! Epilogue + typename Epilogue1_, ///! Epilogue + typename OutputOp2_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. + bool StoreD0, + bool StoreD1 +> +struct DualGemm { + + using DualMma = DualMma_; + + using Epilogue0 = Epilogue0_; + using Epilogue1 = Epilogue1_; + using OutputOp0 = typename Epilogue0::OutputOp; + using OutputOp1 = typename Epilogue1::OutputOp; + using OutputOp2 = OutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static constexpr bool kStoreD0 = StoreD0; + static constexpr bool kStoreD1 = StoreD1; + + using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue< + typename Epilogue0::Shape, + typename Epilogue0::WarpMmaOperator, + Epilogue0::kPartitionsK, + typename Epilogue0::OutputTileIterator, + typename Epilogue0::AccumulatorFragmentIterator, + typename Epilogue0::WarpTileIterator, + typename Epilogue0::SharedLoadIterator, + OutputOp0, + OutputOp1, + OutputOp2, + typename Epilogue0::Padding, + kStoreD0, + kStoreD1, + Epilogue0::kFragmentsPerIteration, + true // IterationsUnroll + >; + + static bool const kSplitKSerial = SplitKSerial; + static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1), + "Split-K serial requires buffers for D0/D1 for reduction"); + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename DualMma::WarpCount; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + // Mma0 + typename DualMma::IteratorA::Params params_A0; + typename DualMma::IteratorA::TensorRef ref_A0; + typename DualMma::IteratorB::Params params_B0; + typename DualMma::IteratorB::TensorRef ref_B0; + typename Epilogue0::OutputTileIterator::Params params_C0; + typename Epilogue0::OutputTileIterator::TensorRef ref_C0; + typename Epilogue0::OutputTileIterator::Params params_D0; + typename Epilogue0::OutputTileIterator::TensorRef ref_D0; + typename OutputOp0::Params output_op_0; + + // Mma1 + typename DualMma::IteratorB::Params params_B1; + typename DualMma::IteratorB::TensorRef ref_B1; + typename Epilogue1::OutputTileIterator::Params params_C1; + typename Epilogue1::OutputTileIterator::TensorRef ref_C1; + typename Epilogue1::OutputTileIterator::Params params_D1; + typename Epilogue1::OutputTileIterator::TensorRef ref_D1; + typename OutputOp1::Params output_op_1; + + typename Epilogue1::OutputTileIterator::Params params_D2; + typename Epilogue1::OutputTileIterator::TensorRef ref_D2; + typename OutputOp2::Params output_op_2; + + int *semaphore; + int gemm_k_size; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + // Mma0: D0 = A @ B0 + C0 + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + // Mma1: D1 = A @ B1 + C1 + typename DualMma::IteratorB::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + + typename Epilogue1::OutputTileIterator::TensorRef ref_D2, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(), + int *workspace = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + // Mma0 + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + params_D0(ref_D0.layout()), + ref_D0(ref_D0), + // Mma1 + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + params_D2(ref_D2.layout()), + ref_D2(ref_D2), + output_op_0(output_op_0), + output_op_1(output_op_1), + output_op_2(output_op_2) { + + int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size = gemm_k_iterations * DualMma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename DualMma::SharedStorage main_loop; + typename DualEpilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DualGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + typename DualMma::IteratorB::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D2) { + + static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D2, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * DualMma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B0{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_B1{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = + (params.problem_size.k() < (threadblock_tile_offset.k() + 1) * params.gemm_k_size) ? + params.problem_size.k() : + (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename DualMma::IteratorA iterator_A0( + params.params_A0, + params.ref_A0.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A0); + + typename DualMma::IteratorB iterator_B0( + params.params_B0, + params.ref_B0.data(), + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B0); + + typename DualMma::IteratorB iterator_B1( + params.params_B1, + params.ref_B1.data(), + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B1); + + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + + // Construct thread-scoped matrix multiply + typename DualMma::FragmentC accum0; + typename DualMma::FragmentC accum1; + accum0.clear(); + accum1.clear(); + + DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accum0, accum1, + iterator_A0, iterator_B0, iterator_B1, + accum0, accum1); + } + + // + // Epilogue + // + + OutputOp0 output_op_0(params.output_op_0); + OutputOp1 output_op_1(params.output_op_1); + OutputOp2 output_op_2(params.output_op_2); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * DualMma::Shape::kM, + threadblock_tile_offset.n() * DualMma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue0::OutputTileIterator iterator_C0( + params.params_C0, + params.ref_C0.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_C1( + params.params_C1, + params.ref_C1.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue0::OutputTileIterator iterator_D0( + params.params_D0, + params.ref_D0.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D1( + params.params_D1, + params.ref_D1.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D2( + params.params_D2, + params.ref_D2.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + DualEpilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C0 = iterator_D0; + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + typename Epilogue0::OutputTileIterator source_iters[] = { + iterator_C0, iterator_C1 + }; + const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1); + epilogue( + output_op_0, output_op_1, output_op_2, + iterator_D0, iterator_D1, iterator_D2, + accum0, accum1, + source_iters, + writeToD2 + ); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/examples/43_dual_gemm/test_run.h b/examples/43_dual_gemm/test_run.h new file mode 100644 index 000000000..b14becafc --- /dev/null +++ b/examples/43_dual_gemm/test_run.h @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#include + +// Run tests on GPUs + +int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { + + bool supported = false; + + int arch_major = arch / 10; + int arch_minor = arch - arch / 10 * 10; + + if(arch_major >= 8) { + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { + supported = true; + } + } + else if(arch_major >= 7) { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { + supported = true; + } + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!(props.major == arch_major && props.minor == arch_minor)) { + supported = false; + } + + if (!supported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + std::cout << "This example isn't supported on current architecture" << std::endl; + return 0; + } + + bool pass = true; + + std::cout << "Device: " << props.name << std::endl; + std::cout << "Arch: SM" << arch << std::endl; + std::cout << "Test: " << test_name << std::endl; + for(auto func : test_funcs) { + pass &= func(); + } + + + if(pass) + return 0; + else + return -1; + +} + diff --git a/examples/43_dual_gemm/thread/left_silu_and_mul.h b/examples/43_dual_gemm/thread/left_silu_and_mul.h new file mode 100644 index 000000000..d93221f18 --- /dev/null +++ b/examples/43_dual_gemm/thread/left_silu_and_mul.h @@ -0,0 +1,150 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. + ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LeftSiLUAndMul { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params{}; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LeftSiLUAndMul(Params const &/*params*/) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(converted_lhs); + return compute_to_output(mul(silu_lhs, converted_rhs)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()( + ElementAccumulator const& lhs, + ElementAccumulator const& rhs + ) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(convert_lhs); + return ElementOutput(mul(silu_lhs, convert_rhs)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/43_dual_gemm/threadblock/dual_epilogue.h b/examples/43_dual_gemm/threadblock/dual_epilogue.h new file mode 100644 index 000000000..54ce8f362 --- /dev/null +++ b/examples/43_dual_gemm/threadblock/dual_epilogue.h @@ -0,0 +1,430 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + ///< Output operator + typename OutputOp0_, + typename OutputOp1_, + typename OutputOp2_, + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + bool StoreD0 = true, + bool StoreD1 = true, + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class DualEpilogue { + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp0 = OutputOp0_; + using OutputOp1 = OutputOp1_; + using OutputOp2 = OutputOp2_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + struct SharedStorage { + using Element = typename WarpTileIterator::Element; + + /// Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + /// Logical shape of the shared memory tile written to by all warps. + using Shape = typename Base::Shape; + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = typename Base::SharedStorage::StorageShape; + + // + // Data members + // + + AlignedBuffer storage[2]; + + // + // Methods + // + + /// Returns a tensor reference to the shared memory buffer + CUTLASS_DEVICE + TensorRef reference(int i) { + return TensorRef( + storage[i].data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles; + +public: + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator0_; + SharedLoadIterator shared_load_iterator1_; + + /// Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator0_; + WarpTileIterator warp_tile_iterator1_; + +public: + + /// Constructor + CUTLASS_DEVICE + DualEpilogue( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + shared_load_iterator0_(shared_storage.reference(0), thread_idx), + shared_load_iterator1_(shared_storage.reference(1), thread_idx), + warp_tile_iterator0_(shared_storage.reference(0), lane_idx), + warp_tile_iterator1_(shared_storage.reference(1), lane_idx) + { + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator0_.add_tile_offset(warp_offset); + warp_tile_iterator1_.add_tile_offset(warp_offset); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + OutputTileIterator dest0, + OutputTileIterator dest1, + OutputTileIterator dest2, + AccumulatorTile const &accumulator0, + AccumulatorTile const &accumulator1, + OutputTileIterator source_iterator[2], + bool writeToD2 // true if it's the final split-k + ) { + // TODO: Implement when no source is needed + + typename OutputTileIterator::Fragment source_fragment[2]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_fragment[i].clear(); + } + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1}; + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_iterator[i].load(source_fragment[i]); + ++source_iterator[i]; + } + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[0], this->warp_tile_iterator0_); + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[1], this->warp_tile_iterator1_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK]; + typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK]; + + shared_load_iterator0_.load(aligned_accum_fragment0[0]); + shared_load_iterator1_.load(aligned_accum_fragment1[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) { + + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator0_.load(aligned_accum_fragment0[i]); + shared_load_iterator1_.load(aligned_accum_fragment1[i]); + aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]); + aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]); + } + + shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment[3]; + + apply_output_operator_(output_fragment, + output_op0, output_op1, output_op2, + aligned_accum_fragment0[0], aligned_accum_fragment1[0], + source_fragment); + + + // + // Store the final result + // + + if (kStoreD0) { + dest0.store(output_fragment[0]); + ++dest0; + } + if (kStoreD1) { + dest1.store(output_fragment[1]); + ++dest1; + } + if (writeToD2) { + dest2.store(output_fragment[2]); + ++dest2; + } + } + } + +private: + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment (&output_fragment)[3], + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment0, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment1, + typename OutputTileIterator::Fragment const (&source_fragment)[2]) { + + OutputAccessType* output_frag_ptr[3] = { + reinterpret_cast(&output_fragment[0]), + reinterpret_cast(&output_fragment[1]), + reinterpret_cast(&output_fragment[2]) + }; + + AccumulatorAccessType const *compute_frag_ptr[2] = { + reinterpret_cast(&aligned_accum_fragment0), + reinterpret_cast(&aligned_accum_fragment1) + }; + + OutputAccessType const *source_frag_ptr[2] = { + reinterpret_cast(&source_fragment[0]), + reinterpret_cast(&source_fragment[1]) + }; + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operators + output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); + output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); + output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/43_dual_gemm/threadblock/dual_mma_base.h b/examples/43_dual_gemm/threadblock/dual_mma_base.h new file mode 100644 index 000000000..975eb137b --- /dev/null +++ b/examples/43_dual_gemm/threadblock/dual_mma_base.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DualMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B0; + AlignedBuffer operand_B1; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B0_ref() { + return TensorRefB{operand_B0.data(), LayoutB()}; + } + CUTLASS_HOST_DEVICE + TensorRefB operand_B1_ref() { + return TensorRefB{operand_B1.data(), LayoutB()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B0_; + typename Operator::IteratorB warp_tile_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx), + warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/43_dual_gemm/threadblock/dual_mma_multistage.h b/examples/43_dual_gemm/threadblock/dual_mma_multistage.h new file mode 100644 index 000000000..1a84aa686 --- /dev/null +++ b/examples/43_dual_gemm/threadblock/dual_mma_multistage.h @@ -0,0 +1,760 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "dual_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class DualMmaMultistage : + public DualMmaBase { +public: + ///< Base class + using Base = DualMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B0_; + SmemIteratorB smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx), + smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B0.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + iterator_B1.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B); + this->smem_iterator_B1_.set_iteration_index(group_start_B); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum0, + FragmentC &accum1, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B0, + IteratorB iterator_B1, + ///< initial value of accumulator + FragmentC const &src_accum0, + FragmentC const &src_accum1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B0.set_iteration_index(0); + iterator_B1.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum0 = src_accum0; + accum1 = src_accum1; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + typename IteratorB::AccessType zero_B; + zero_B.clear(); + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_); + last_smem_iterator_B0.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B0.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B0; + } + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_); + last_smem_iterator_B1.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B1.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B1; + } + } + + // Waits until stages up to the previous (kStages-2)th stage have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B0[2]; + WarpLoadedFragmentB warp_loaded_frag_B1[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B0[2]; + WarpTransformedFragmentB warp_transformed_frag_B1[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum0, tmp_accum1; + + if (platform::is_same::value + || platform::is_same::value) { + + tmp_accum0.clear(); + tmp_accum1.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) { + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + } + + if (platform::is_same::value + || platform::is_same::value) { + + warp_mma( + tmp_accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + tmp_accum0 + ); + warp_mma( + tmp_accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum1 + ); + + if (warp_mma_k == 0) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + tmp_accum0.clear(); + tmp_accum1.clear(); + } + } else { + warp_mma( + accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + warp_mma( + accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum1 + ); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until stages up to the previous (kStages-2)th stage have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + } + + if (platform::is_same::value + || platform::is_same::value) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 089ee9612..bd75a7423 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -121,6 +121,7 @@ foreach(EXAMPLE 39_gemm_permute 41_multi_head_attention 42_fused_multi_head_attention + 43_dual_gemm ) add_subdirectory(${EXAMPLE}) diff --git a/include/cutlass/gemm/threadblock/mma_base.h b/include/cutlass/gemm/threadblock/mma_base.h index e2cb4a477..13414d55e 100644 --- a/include/cutlass/gemm/threadblock/mma_base.h +++ b/include/cutlass/gemm/threadblock/mma_base.h @@ -34,6 +34,7 @@ #pragma once +#include "cutlass/tensor_ref.h" #include "cutlass/aligned_buffer.h" #include "cutlass/arch/memory.h" #include "cutlass/array.h"