diff --git a/examples/45_dual_gemm/device/dual_gemm.h b/examples/45_dual_gemm/device/dual_gemm.h
index 491888bef..71f7973ed 100644
--- a/examples/45_dual_gemm/device/dual_gemm.h
+++ b/examples/45_dual_gemm/device/dual_gemm.h
@@ -52,6 +52,7 @@ D2 = element_wise(D0, D1)
 
 #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
 
 #include "../kernel/dual_gemm.h"
+#include "../dual_gemm_common.h"
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -68,8 +69,10 @@ template <
     typename LayoutA_,
     /// Element type for B matrix operand
     typename ElementB_,
-    /// Layout type for B matrix operand
-    typename LayoutB_,
+    /// Layout type for B0 matrix operand
+    typename LayoutB0_,
+    /// Layout type for B1 matrix operand
+    typename LayoutB1_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
@@ -119,8 +122,10 @@ class DualGemm {
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
-  using LayoutB = LayoutB_;
-  using TensorRefB = TensorRef<ElementB const, LayoutB>;
+  using LayoutB0 = LayoutB0_;
+  using LayoutB1 = LayoutB1_;
+  using TensorRefB0 = TensorRef<ElementB const, LayoutB0>;
+  using TensorRefB1 = TensorRef<ElementB const, LayoutB1>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
@@ -151,23 +156,31 @@ class DualGemm {
  /// Define the threadblock-scoped matrix multiply-accumulate
  static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
  static_assert(kStages >= 3, "Only multistage is implemented");
-  using Mma = typename cutlass::gemm::threadblock::DefaultMma<
-    ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
+  using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
+    ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
+    ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
+    ThreadblockShape, WarpShape,
+    InstructionShape, Stages, Operator>::ThreadblockMma;
+  using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
+    ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
    ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
    ThreadblockShape, WarpShape,
    InstructionShape, Stages, Operator>::ThreadblockMma;

  using DualMma = threadblock::DualMmaMultistage<
-    typename Mma::Shape,
-    typename Mma::IteratorA,
-    typename Mma::SmemIteratorA,
-    Mma::kCacheOpA,
-    typename Mma::IteratorB,
-    typename Mma::SmemIteratorB,
-    Mma::kCacheOpB,
-    typename Mma::ElementC,
-    typename Mma::LayoutC,
-    typename Mma::Policy,
-    Mma::kStages,
+    typename Mma0::Shape,
+    typename Mma0::IteratorA,
+    typename Mma0::SmemIteratorA,
+    Mma0::kCacheOpA,
+    typename Mma0::IteratorB,
+    typename Mma0::SmemIteratorB,
+    Mma0::kCacheOpB,
+    typename Mma1::IteratorB,
+    typename Mma1::SmemIteratorB,
+    typename Mma0::ElementC,
+    typename Mma0::LayoutC,
+    typename Mma0::Policy,
+    typename Mma1::Policy,
+    Mma0::kStages,
    SharedMemoryClearOption::kNone
  >;
@@ -176,11 +189,11 @@ class DualGemm {

  /// Define the epilogue
  using Epilogue0 = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
-      ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0,
+      ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0,
      EpilogueOutputOp0::kCount>::Epilogue;
  using Epilogue1 = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
-      ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1,
+      ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1,
      EpilogueOutputOp1::kCount>::Epilogue;

  /// Define the kernel-level GEMM operator.
@@ -197,12 +210,13 @@ class DualGemm {
    // Data members
    //

+    DualGemmMode mode;
    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A0;
-    TensorRef<ElementB const, LayoutB> ref_B0;
+    TensorRef<ElementB const, LayoutB0> ref_B0;
    TensorRef<ElementC const, LayoutC> ref_C0;
    TensorRef<ElementC, LayoutC> ref_D0;
-    TensorRef<ElementB const, LayoutB> ref_B1;
+    TensorRef<ElementB const, LayoutB1> ref_B1;
    TensorRef<ElementC const, LayoutC> ref_C1;
    TensorRef<ElementC, LayoutC> ref_D1;
    TensorRef<ElementC, LayoutC> ref_D2;
@@ -211,6 +225,13 @@ class DualGemm {
    typename EpilogueOutputOp2::Params epilogue2;
    int split_k_slices;

+    int batch_count;
+    int64_t batch_stride_A;
+    int64_t batch_stride_B0;
+    int64_t batch_stride_B1;
+    int64_t batch_stride_C;
+    int64_t batch_stride_D;
+
    //
    // Methods
    //
@@ -224,12 +245,13 @@ class DualGemm {
    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
+      DualGemmMode mode,
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A0_,
-      TensorRef<ElementB const, LayoutB> ref_B0_,
+      TensorRef<ElementB const, LayoutB0> ref_B0_,
      TensorRef<ElementC const, LayoutC> ref_C0_,
      TensorRef<ElementC, LayoutC> ref_D0_,
-      TensorRef<ElementB const, LayoutB> ref_B1_,
+      TensorRef<ElementB const, LayoutB1> ref_B1_,
      TensorRef<ElementC const, LayoutC> ref_C1_,
      TensorRef<ElementC, LayoutC> ref_D1_,
      TensorRef<ElementC, LayoutC> ref_D2_,
@@ -239,8 +261,15 @@ class DualGemm {
        typename EpilogueOutputOp1::Params(),
      typename EpilogueOutputOp2::Params epilogue2_ =
        typename EpilogueOutputOp2::Params(),
-      int split_k_slices_ = 1
+      int split_k_slices_ = 1,
+      int batch_count = 1,
+      int64_t batch_stride_A = 0,
+      int64_t batch_stride_B0 = 0,
+      int64_t batch_stride_B1 = 0,
+      int64_t batch_stride_C = 0,
+      int64_t batch_stride_D = 0
    ):
+      mode(mode),
      problem_size(problem_size_),
      ref_A0(ref_A0_),
      ref_B0(ref_B0_),
@@ -253,7 +282,13 @@ class DualGemm {
      epilogue0(epilogue0_),
      epilogue1(epilogue1_),
      epilogue2(epilogue2_),
-      split_k_slices(split_k_slices_) {
+      split_k_slices(split_k_slices_),
+      batch_count(batch_count),
+      batch_stride_A(batch_stride_A),
+      batch_stride_B0(batch_stride_B0),
+      batch_stride_B1(batch_stride_B1),
+      batch_stride_C(batch_stride_C),
+      batch_stride_D(batch_stride_D) {

    }
  };
@@ -271,6 +306,9 @@ public:

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {
+    if (args.mode == DualGemmMode::kBatched && kSplitKSerial) {
+      return Status::kErrorInvalidProblem;
+    }
    if (!kSplitKSerial && args.split_k_slices > 1) {
      return Status::kErrorInvalidProblem;
    }
@@ -304,17 +342,15 @@ public:
  static size_t get_workspace_size(Arguments const &args) {

    size_t bytes = 0;
-
-    // Determine grid shape
-    ThreadblockSwizzle threadblock_swizzle;
-
-    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
-      args.problem_size,
-      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
-      args.split_k_slices);

    if (kSplitKSerial && args.split_k_slices > 1) {
+      // Determine grid shape
+      ThreadblockSwizzle threadblock_swizzle;
+      cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
+        args.problem_size,
+        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
+        args.split_k_slices);
      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return bytes;
@@ -331,7 +367,7 @@ public:
    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
-      args.split_k_slices);
+      args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices);

    if (kSplitKSerial) {
      if (args.split_k_slices > 1) {
@@ -357,6 +393,7 @@ public:

    // Initialize the Params structure
    params_ = typename DualGemmKernel::Params{
+      args.mode,
      args.problem_size,
      grid_shape,
      args.ref_A0.non_const_ref(),
@@ -371,6 +408,11 @@ public:
      args.epilogue1,
      args.epilogue2,
      reinterpret_cast<int *>(workspace),
+      args.batch_stride_A,
+      args.batch_stride_B0,
+      args.batch_stride_B1,
+      args.batch_stride_C,
+      args.batch_stride_D,
    };

    return Status::kSuccess;
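Aside: a minimal, hedged sketch of filling the extended Arguments in batched mode. M, N, K, batch_count, the ref_* tensor refs, and the alpha/beta values are illustrative placeholders, not part of the patch; DualGemm is an instantiation of the device-level template above.

    // Hypothetical call site exercising the new batched-mode arguments.
    typename DualGemm::Arguments args(
        cutlass::gemm::DualGemmMode::kBatched,
        {M, N, K},                        // per-batch problem size
        ref_A0, ref_B0, ref_C0, ref_D0,   // GEMM 0 operands
        ref_B1, ref_C1, ref_D1,           // GEMM 1 operands
        ref_D2,                           // element-wise output D2
        {alpha0, beta0},                  // epilogue 0 params
        {alpha1, beta1},                  // epilogue 1 params
        {},                               // epilogue 2 params
        /*split_k_slices=*/1,             // can_implement() rejects kBatched + split-K serial
        batch_count,
        int64_t(M) * K,                   // batch_stride_A
        int64_t(K) * N,                   // batch_stride_B0
        int64_t(K) * N,                   // batch_stride_B1
        int64_t(N),                       // batch_stride_C (one bias row per batch, as in the harness)
        int64_t(M) * N);                  // batch_stride_D

Batched mode reuses the threadblock swizzle's k dimension as the batch index, which is why initialize() passes args.batch_count instead of args.split_k_slices to get_tiled_shape.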
passed; +} + +bool run_broadcast_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + // different LayoutB0 and B1 + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Broadcast Fused FP16 TN GEMMs + Epilogue2...\n"; + + bool passed = fusedGemm.run( + problem_size, + alpha0, + beta0, + alpha1, + beta1, + 1, /* batch_count */ + true, /* broadcast_b1 */ + true /* is_profiling */ + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} + +bool run_batched_broadcast_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + // different LayoutB0 and B1 + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Batch Broadcast Fused FP16 TN GEMMs + Epilogue2...\n"; + + bool passed = fusedGemm.run( + batch_problem_size, + alpha0, + beta0, + alpha1, + beta1, + kBatchCount, + true, /* broadcast_b1 */ + false /* is_profiling */ + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; } int main() { std::vectorfuncs = { &run_nonfused_gemm_f16_sm80, - &run_fused_gemm_f16_sm80_shmem + &run_fused_gemm_f16_sm80_shmem, + &run_batched_fused_gemm_f16_sm80_shmem, + &run_broadcast_fused_gemm_f16_sm80_shmem, + &run_batched_broadcast_fused_gemm_f16_sm80_shmem }; - std::string test_name = "dual-gemm f16 bias=" + std::to_string(kUseBias) + " split_k_serial=" + std::to_string(kSplitKSerial); + std::string test_name = ( + "dual-gemm f16 bias=" + + std::to_string(kUseBias) + + " split_k_serial=" + + std::to_string(kSplitKSerial) + + " batch_count=" + + std::to_string(kBatchCount) + ); + return testRun(80, funcs, test_name); } diff --git a/examples/45_dual_gemm/dual_gemm_common.h b/examples/45_dual_gemm/dual_gemm_common.h new file mode 100644 index 000000000..615b480ef --- /dev/null +++ b/examples/45_dual_gemm/dual_gemm_common.h @@ -0,0 +1,52 @@ 
diff --git a/examples/45_dual_gemm/dual_gemm_common.h b/examples/45_dual_gemm/dual_gemm_common.h
new file mode 100644
index 000000000..615b480ef
--- /dev/null
+++ b/examples/45_dual_gemm/dual_gemm_common.h
@@ -0,0 +1,52 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Defines common types used for all DualGemm operators.
+*/
+#pragma once
+
+namespace cutlass {
+namespace gemm {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+enum class DualGemmMode {
+  kGemm,
+  kBatched,
+  kInvalid
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/45_dual_gemm/dual_gemm_run.h b/examples/45_dual_gemm/dual_gemm_run.h
index 63ca2ac84..d6a52d58e 100644
--- a/examples/45_dual_gemm/dual_gemm_run.h
+++ b/examples/45_dual_gemm/dual_gemm_run.h
@@ -33,6 +33,7 @@
 #include <iostream>
 #include <fstream>
 #include <sstream>
+#include <type_traits>
 
 #include "cutlass/util/host_tensor.h"
 #include "cutlass/util/tensor_view_io.h"
@@ -44,6 +45,10 @@
 #include "cutlass/util/reference/device/gemm.h"
 #include "cutlass/util/reference/device/tensor_relu.h"
 
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "dual_gemm_common.h"
 #include "helper.h"
 
 #define CHECK_GT(val1, val2) \
@@ -205,6 +210,7 @@ struct NonFusedDualGemmRun
    ElementCompute beta0 = ElementCompute(0),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(0),
+    bool is_profiling = true,
    bool relu = false,
    int warm_ups = 1,
    int runs = 100) {
@@ -335,40 +341,41 @@ struct NonFusedDualGemmRun
      status = gemm_op_1();
      CUTLASS_CHECK(status);
    }
-#ifdef IS_PROFILING
-    return true;
-#endif
-    //
-    // Run the GEMM
-    //
-    cudaEvent_t start, stop1, stop2;
-    cudaEventCreate(&start);
-    cudaEventCreate(&stop1);
-    cudaEventCreate(&stop2);
-    cudaEventRecord(start);

-    for(int i = 0; i < runs; i++) {
-      status = gemm_op_0();
-
-      CUTLASS_CHECK(status);
+    if (is_profiling) {
+      //
+      // Profile the GEMM
+      //
+
+      cudaEvent_t start, stop1, stop2;
+      cudaEventCreate(&start);
+      cudaEventCreate(&stop1);
+      cudaEventCreate(&stop2);
+
+      cudaEventRecord(start);
+
+      for(int i = 0; i < runs; i++) {
+        status = gemm_op_0();
+
+        CUTLASS_CHECK(status);
+      }
+      cudaEventRecord(stop1);
+      for(int i = 0; i < runs; i++) {
+        status = gemm_op_1();
+
+        CUTLASS_CHECK(status);
+      }
+
+      cudaEventRecord(stop2);
+      cudaDeviceSynchronize();
+      float gemm0Time, gemm1Time, totalTime;
+      cudaEventElapsedTime(&gemm0Time, start, stop1);
+      cudaEventElapsedTime(&gemm1Time, stop1, stop2);
+      cudaEventElapsedTime(&totalTime, start, stop2);
+      std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
+      std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
+      std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
    }
-    cudaEventRecord(stop1);
-    for(int i = 0; i < runs; i++) {
-      status = gemm_op_1();
-
-      CUTLASS_CHECK(status);
-    }
-
-    cudaEventRecord(stop2);
-    cudaDeviceSynchronize();
-    float gemm0Time, gemm1Time, totalTime;
-    cudaEventElapsedTime(&gemm0Time, start, stop1);
-    cudaEventElapsedTime(&gemm1Time, stop1, stop2);
-    cudaEventElapsedTime(&totalTime, start, stop2);
-    std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
-    std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
-    std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";

    tensor_D0.sync_host();
    tensor_D1.sync_host();
@@ -543,6 +550,9 @@ struct DualFusedGemmRun
    ElementCompute beta0 = ElementCompute(1),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(1),
+    int batch_count = 1,
+    bool broadcast_b1 = false,
+    bool is_profiling = true,
    bool relu = false,
    int warm_ups = 1,
    int runs = 100) {
@@ -553,55 +563,91 @@ struct DualFusedGemmRun

    cutlass::HostTensor<
      typename DualGemm::ElementA,
-      typename DualGemm::LayoutA> tensor_A0(problem_size.mk());
+      typename DualGemm::LayoutA> tensor_A0(
+        std::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));

    cutlass::HostTensor<
      typename DualGemm::ElementB,
-      typename DualGemm::LayoutB> tensor_B0(problem_size.kn());
+      typename DualGemm::LayoutB0> tensor_B0(
+        std::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_C0(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_C0(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()});
+      typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()});

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_D0(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_D0(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> reference_D0(problem_size.mn());
+      typename DualGemm::LayoutC> reference_D0(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementB,
-      typename DualGemm::LayoutB> tensor_B1(problem_size.kn());
+      typename DualGemm::LayoutB1> tensor_B1(
+        std::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
+    if (broadcast_b1) {
+      tensor_B1.resize({problem_size.k(), batch_count});
+    }

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_C1(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_C1(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()});
+      typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()});

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_D1(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_D1(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_D2(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_D2(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> reference_D1(problem_size.mn());
+      typename DualGemm::LayoutC> reference_D1(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> reference_D2(problem_size.mn());
+      typename DualGemm::LayoutC> reference_D2(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
@@ -638,6 +684,20 @@ struct DualFusedGemmRun
    reference_D1.sync_device();
    reference_D2.sync_device();

+    //
+    // Batch strides (irrelevant when batch_count == 1)
+    //
+
+    int64_t batch_stride_A = problem_size.m() * problem_size.k();
+    int64_t batch_stride_B0 = problem_size.k() * problem_size.n();
+    int64_t batch_stride_B1 = problem_size.k() * problem_size.n();
+    if (broadcast_b1) {
+      // B1 is a (column) vector
+      batch_stride_B1 = problem_size.k();
+    }
+    int64_t batch_stride_Bias = problem_size.n();
+    int64_t batch_stride_D = problem_size.m() * problem_size.n();
+
    //
    // Initialize the GEMM operator
    //
@@ -652,21 +712,36 @@ struct DualFusedGemmRun
      ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
    }
    typename DualGemm::Arguments arguments{
+      (batch_count > 1 ?
+        cutlass::gemm::DualGemmMode::kBatched :
+        cutlass::gemm::DualGemmMode::kGemm),
      problem_size,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      ref_B0,
      DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
-      tensor_B1.device_ref(),
+      (broadcast_b1 ?
+        typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
+        tensor_B1.device_ref()),
      ref_B1,
      DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
      tensor_D2.device_ref(),
      {alpha0, beta0},
      {alpha1, beta1},
      {},
-      split_k_slices
+      split_k_slices,
+      batch_count,
+      batch_stride_A,
+      batch_stride_B0,
+      batch_stride_B1,
+      batch_stride_Bias,
+      batch_stride_D,
    };

+    //
+    // Run the GEMM
+    //
+
    DualGemm b2b_gemm_op;

    cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));
@@ -684,31 +759,29 @@ struct DualFusedGemmRun
      CUTLASS_CHECK(status);
    }

-#ifdef IS_PROFILING
-    return true;
-#endif
-    //
-    // Run the GEMM
-    //
+    if (is_profiling) {
+      //
+      // Profile the GEMM
+      //

-    cudaEvent_t start, stop;
-    cudaEventCreate(&start);
-    cudaEventCreate(&stop);
+      cudaEvent_t start, stop;
+      cudaEventCreate(&start);
+      cudaEventCreate(&stop);

-    cudaEventRecord(start);
+      cudaEventRecord(start);

-    for(int i = 0; i < runs; i++) {
-      status = b2b_gemm_op();
+      for(int i = 0; i < runs; i++) {
+        status = b2b_gemm_op();
+        CUTLASS_CHECK(status);
+      }

-      CUTLASS_CHECK(status);
+      cudaEventRecord(stop);
+      cudaDeviceSynchronize();
+      float gemmTime;
+      cudaEventElapsedTime(&gemmTime, start, stop);
+      std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
    }

-    cudaEventRecord(stop);
-    cudaDeviceSynchronize();
-    float gemmTime;
-    cudaEventElapsedTime(&gemmTime, start, stop);
-    std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
-
    tensor_D0.sync_host();
    tensor_D1.sync_host();
    tensor_D2.sync_host();
@@ -717,45 +790,81 @@ struct DualFusedGemmRun
    // Verify
    //

-    cutlass::reference::device::Gemm<
-      typename DualGemm::ElementA, typename DualGemm::LayoutA,
-      typename DualGemm::ElementB, typename DualGemm::LayoutB,
-      typename DualGemm::ElementC, typename DualGemm::LayoutC,
-      ElementAccumulator, ElementAccumulator>
-      reference_gemm_0;
+    using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
+      typename DualGemm::ElementA, typename DualGemm::LayoutA,
+      typename DualGemm::ElementB, typename DualGemm::LayoutB0,
+      typename DualGemm::ElementC, typename DualGemm::LayoutC,
+      ElementAccumulator
+    >;

-    cutlass::reference::device::Gemm<
-      typename DualGemm::ElementA, typename DualGemm::LayoutA,
-      typename DualGemm::ElementB, typename DualGemm::LayoutB,
-      typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute,
-      ElementAccumulator, typename DualGemm::Operator>
-      reference_gemm_1;
+    GemmUniversal0 reference_gemm0;

-    reference_gemm_0(
+    typename GemmUniversal0::Arguments args0 {
+      (batch_count > 1 ?
+        cutlass::gemm::GemmUniversalMode::kBatched :
+        cutlass::gemm::GemmUniversalMode::kGemm),
      problem_size,
-      alpha0,
-      tensor_A0.device_ref(),
-      tensor_B0.device_ref(),
-      beta0,
-      {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)},
-      reference_D0.device_ref()
-    );
+      batch_count,
+      {alpha0, beta0},
+      tensor_A0.device_data(),
+      tensor_B0.device_data(),
+      tensor_Bias0.device_data(),
+      reference_D0.device_data(),
+      batch_stride_A,
+      batch_stride_B0,
+      batch_stride_Bias,
+      batch_stride_D,
+      tensor_A0.stride(0),
+      tensor_B0.stride(0),
+      0, // zero stride for the bias vector
+      reference_D0.stride(0),
+    };
+
+    status = reference_gemm0.can_implement(args0);
+    CUTLASS_CHECK(status);
+    status = reference_gemm0(args0);
+    CUTLASS_CHECK(status);
+
+    using GemmUniversal1 = cutlass::gemm::device::GemmUniversal<
+      typename DualGemm::ElementA, typename DualGemm::LayoutA,
+      typename DualGemm::ElementB, typename DualGemm::LayoutB1,
+      typename DualGemm::ElementC, typename DualGemm::LayoutC,
+      ElementAccumulator
+    >;
+
+    GemmUniversal1 reference_gemm1;
+
+    typename GemmUniversal1::Arguments args1 {
+      (batch_count > 1 ?
+        cutlass::gemm::GemmUniversalMode::kBatched :
+        cutlass::gemm::GemmUniversalMode::kGemm),
+      problem_size,
+      batch_count,
+      {alpha1, beta1},
+      tensor_A0.device_data(),
+      tensor_B1.device_data(),
+      tensor_Bias1.device_data(),
+      reference_D1.device_data(),
+      batch_stride_A,
+      batch_stride_B1,
+      batch_stride_Bias,
+      batch_stride_D,
+      tensor_A0.stride(0),
+      (broadcast_b1 ? 0 : tensor_B1.stride(0)),
+      0, // zero stride for the bias vector
+      reference_D1.stride(0),
+    };
+
+    status = reference_gemm1.can_implement(args1);
+    CUTLASS_CHECK(status);
+    status = reference_gemm1(args1);
+    CUTLASS_CHECK(status);
+
    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D0.device_view());
-    }
-
-    reference_gemm_1(
-      problem_size,
-      alpha1,
-      tensor_A0.device_ref(),
-      tensor_B1.device_ref(),
-      beta1,
-      {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)},
-      reference_D1.device_ref()
-    );
-    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }
+
    TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view());
    cudaDeviceSynchronize();
    reference_D0.sync_host();
@@ -793,7 +902,6 @@ struct DualFusedGemmRun
    bool passed = passed_out0 && passed_out1 && passed_out2;
    if (!passed) {
-
      std::stringstream fname;
      fname << "error_DualGemm_device_fused.txt";
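The harness's repeated ternaries all apply one rule: fold the batch dimension onto the strided extent of the layout, so all batches sit back to back in a single allocation exactly batch_stride elements apart. Stated once as a standalone helper, a hedged sketch (the function name is ours, not the harness's):

    #include <type_traits>
    #include "cutlass/layout/matrix.h"
    #include "cutlass/matrix_coord.h"

    // Extent of a (rows x cols) operand replicated batch_count times in Layout.
    template <typename Layout>
    cutlass::MatrixCoord batched_extent(int rows, int cols, int batch_count) {
      // RowMajor: batches stack along rows; ColumnMajor: along columns.
      // Either way, consecutive batches are rows * cols elements apart.
      return std::is_same<Layout, cutlass::layout::RowMajor>::value
                 ? cutlass::MatrixCoord(batch_count * rows, cols)
                 : cutlass::MatrixCoord(rows, batch_count * cols);
    }

The bias tensors are the exception: they grow to {batch_count, n} directly, one bias row per batch, which is why batch_stride_Bias is problem_size.n().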
diff --git a/examples/45_dual_gemm/kernel/dual_gemm.h b/examples/45_dual_gemm/kernel/dual_gemm.h
index 4cbddaa77..56ed9e7ea 100644
--- a/examples/45_dual_gemm/kernel/dual_gemm.h
+++ b/examples/45_dual_gemm/kernel/dual_gemm.h
@@ -42,6 +42,7 @@
 
 #include "../threadblock/dual_mma_multistage.h"
 #include "../threadblock/dual_epilogue.h"
+#include "../dual_gemm_common.h"
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -92,6 +93,10 @@ struct DualGemm {
    true // IterationsUnroll
  >;

+  using ElementA = typename DualMma::IteratorA::Element;
+  using ElementB = typename DualMma::IteratorB0::Element;
+  using ElementC = typename DualEpilogue::OutputTileIterator::Element;
+
  static bool const kSplitKSerial = SplitKSerial;
  static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1),
    "Split-K serial requires buffers for D0/D1 for reduction");
@@ -102,6 +107,7 @@ struct DualGemm {

  /// Parameters structure
  struct Params {
+    DualGemmMode mode;
    cutlass::gemm::GemmCoord problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    int swizzle_log_tile;
@@ -109,8 +115,8 @@ struct DualGemm {
    // Mma0
    typename DualMma::IteratorA::Params params_A0;
    typename DualMma::IteratorA::TensorRef ref_A0;
-    typename DualMma::IteratorB::Params params_B0;
-    typename DualMma::IteratorB::TensorRef ref_B0;
+    typename DualMma::IteratorB0::Params params_B0;
+    typename DualMma::IteratorB0::TensorRef ref_B0;
    typename Epilogue0::OutputTileIterator::Params params_C0;
    typename Epilogue0::OutputTileIterator::TensorRef ref_C0;
    typename Epilogue0::OutputTileIterator::Params params_D0;
@@ -118,8 +124,8 @@ struct DualGemm {
    typename OutputOp0::Params output_op_0;

    // Mma1
-    typename DualMma::IteratorB::Params params_B1;
-    typename DualMma::IteratorB::TensorRef ref_B1;
+    typename DualMma::IteratorB1::Params params_B1;
+    typename DualMma::IteratorB1::TensorRef ref_B1;
    typename Epilogue1::OutputTileIterator::Params params_C1;
    typename Epilogue1::OutputTileIterator::TensorRef ref_C1;
    typename Epilogue1::OutputTileIterator::Params params_D1;
@@ -133,6 +139,12 @@ struct DualGemm {
    int *semaphore;
    int gemm_k_size;

+    int64_t batch_stride_A;
+    int64_t batch_stride_B0;
+    int64_t batch_stride_B1;
+    int64_t batch_stride_C;
+    int64_t batch_stride_D;
+
    //
    // Methods
    //
@@ -142,15 +154,16 @@ struct DualGemm {

    CUTLASS_HOST_DEVICE
    Params(
+      DualGemmMode mode,
      cutlass::gemm::GemmCoord const & problem_size,
      cutlass::gemm::GemmCoord const & grid_tiled_shape,
      // Mma0: D0 = A @ B0 + C0
      typename DualMma::IteratorA::TensorRef ref_A0,
-      typename DualMma::IteratorB::TensorRef ref_B0,
+      typename DualMma::IteratorB0::TensorRef ref_B0,
      typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
      typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
      // Mma1: D1 = A @ B1 + C1
-      typename DualMma::IteratorB::TensorRef ref_B1,
+      typename DualMma::IteratorB1::TensorRef ref_B1,
      typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
      typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
@@ -158,8 +171,14 @@ struct DualGemm {
      typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(),
      typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(),
      typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(),
-      int *workspace = nullptr
+      int *workspace = nullptr,
+      int64_t batch_stride_A = 1,
+      int64_t batch_stride_B0 = 1,
+      int64_t batch_stride_B1 = 1,
+      int64_t batch_stride_C = 1,
+      int64_t batch_stride_D = 1
    ):
+      mode(mode),
      problem_size(problem_size),
      grid_tiled_shape(grid_tiled_shape),
      swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)),
@@ -183,13 +202,18 @@ struct DualGemm {
      ref_D2(ref_D2),
      output_op_0(output_op_0),
      output_op_1(output_op_1),
-      output_op_2(output_op_2) {
+      output_op_2(output_op_2),
+      batch_stride_A(batch_stride_A),
+      batch_stride_B0(batch_stride_B0),
+      batch_stride_B1(batch_stride_B1),
+      batch_stride_C(batch_stride_C),
+      batch_stride_D(batch_stride_D) {

      int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
      int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k();
      gemm_k_size = gemm_k_iterations * DualMma::Shape::kK;

-      semaphore = workspace;
+      semaphore = workspace;
    }
  };
@@ -210,16 +234,16 @@ struct DualGemm {
  static Status can_implement(
    cutlass::gemm::GemmCoord const & problem_size,
    typename DualMma::IteratorA::TensorRef ref_A0,
-    typename DualMma::IteratorB::TensorRef ref_B0,
+    typename DualMma::IteratorB0::TensorRef ref_B0,
    typename Epilogue0::OutputTileIterator::TensorRef ref_C0,
    typename Epilogue0::OutputTileIterator::TensorRef ref_D0,
-    typename DualMma::IteratorB::TensorRef ref_B1,
+    typename DualMma::IteratorB1::TensorRef ref_B1,
    typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
    typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
    typename Epilogue1::OutputTileIterator::TensorRef ref_D2) {

    static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements;
-    static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements;
+    static int const kAlignmentB = DualMma::IteratorB0::AccessType::kElements;
    static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess;

    if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
@@ -273,52 +297,66 @@ struct DualGemm {
      return;
    }

+    int offset_k = 0;
+    int problem_size_k = params.problem_size.k();
+
+    ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
+    ElementB *ptr_B0 = static_cast<ElementB *>(params.ref_B0.data());
+    ElementB *ptr_B1 = static_cast<ElementB *>(params.ref_B1.data());
+
+    //
+    // Fetch pointers based on mode.
+    //
+    if (params.mode == DualGemmMode::kGemm) {
+      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
+        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
+      }
+
+      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
+    }
+    else if (params.mode == DualGemmMode::kBatched) {
+      ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A;
+      ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
+      ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
+    }
+
    // Compute initial location in logical coordinates
    cutlass::MatrixCoord tb_offset_A0{
      threadblock_tile_offset.m() * DualMma::Shape::kM,
-      threadblock_tile_offset.k() * params.gemm_k_size,
+      offset_k,
    };
    cutlass::MatrixCoord tb_offset_B0{
-      threadblock_tile_offset.k() * params.gemm_k_size,
+      offset_k,
      threadblock_tile_offset.n() * DualMma::Shape::kN
    };
    cutlass::MatrixCoord tb_offset_B1{
-      threadblock_tile_offset.k() * params.gemm_k_size,
+      offset_k,
      threadblock_tile_offset.n() * DualMma::Shape::kN
    };

-    // Problem size is a function of threadblock index in the K dimension
-    int problem_size_k =
-      (params.problem_size.k() < (threadblock_tile_offset.k() + 1) * params.gemm_k_size) ?
-      params.problem_size.k() :
-      (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
-
-    // Compute threadblock-scoped matrix multiply-add
-    int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
-
    // Compute position within threadblock
    int thread_idx = threadIdx.x;

    // Construct iterators to A and B operands
    typename DualMma::IteratorA iterator_A0(
      params.params_A0,
-      params.ref_A0.data(),
+      ptr_A0,
      {params.problem_size.m(), problem_size_k},
      thread_idx,
      tb_offset_A0);

-    typename DualMma::IteratorB iterator_B0(
+    typename DualMma::IteratorB0 iterator_B0(
      params.params_B0,
-      params.ref_B0.data(),
+      ptr_B0,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B0);

-    typename DualMma::IteratorB iterator_B1(
+    typename DualMma::IteratorB1 iterator_B1(
      params.params_B1,
-      params.ref_B1.data(),
+      ptr_B1,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B1);
@@ -340,6 +378,9 @@ struct DualGemm {
    accum0.clear();
    accum1.clear();

+    // Compute threadblock-scoped matrix multiply-add
+    int gemm_k_iterations = (problem_size_k - offset_k + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
+
    DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
    if (!kSplitKSerial || gemm_k_iterations > 0) {
      // Compute threadblock-scoped matrix multiply-add
@@ -372,31 +413,46 @@ struct DualGemm {
    int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

+    ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
+    ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
+    ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
+    ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
+    ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
+
    // Construct the semaphore.
    Semaphore semaphore(params.semaphore + block_idx, thread_idx);

-    // If performing a reduction via split-K, fetch the initial synchronization
-    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
-
-      // Fetch the synchronization lock initially but do not block.
-      semaphore.fetch();
+    if (params.mode == DualGemmMode::kGemm) {
+      // If performing a reduction via split-K, fetch the initial synchronization
+      if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+        // Fetch the synchronization lock initially but do not block.
+        semaphore.fetch();

-      // Indicate which position in a serial reduction the output operator is currently updating
-      output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
-      output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+        // Indicate which position in a serial reduction the output operator is currently updating
+        output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+        output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+      }
+    }
+    else if (params.mode == DualGemmMode::kBatched) {
+      ptr_C0 += threadblock_tile_offset.k() * params.batch_stride_C;
+      ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C;
+      ptr_D0 += threadblock_tile_offset.k() * params.batch_stride_D;
+      ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D;
+      ptr_D2 += threadblock_tile_offset.k() * params.batch_stride_D;
    }

    // Tile iterator loading from source tensor.
    typename Epilogue0::OutputTileIterator iterator_C0(
      params.params_C0,
-      params.ref_C0.data(),
+      ptr_C0,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );
    typename Epilogue1::OutputTileIterator iterator_C1(
      params.params_C1,
-      params.ref_C1.data(),
+      ptr_C1,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
@@ -405,21 +461,21 @@ struct DualGemm {
    // Tile iterator writing to destination tensor.
    typename Epilogue0::OutputTileIterator iterator_D0(
      params.params_D0,
-      params.ref_D0.data(),
+      ptr_D0,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );
    typename Epilogue1::OutputTileIterator iterator_D1(
      params.params_D1,
-      params.ref_D1.data(),
+      ptr_D1,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
    );
    typename Epilogue1::OutputTileIterator iterator_D2(
      params.params_D2,
-      params.ref_D2.data(),
+      ptr_D2,
      params.problem_size.mn(),
      thread_idx,
      threadblock_offset
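The kernel-side change reduces to one dispatch on the reused grid k index: under kGemm it selects a split-K slice of the K range, under kBatched it selects a batch and shifts raw pointers. A self-contained hedged sketch of that rule (the struct and function names are illustrative, not from the kernel):

    #include <cstdint>

    enum class DualGemmMode { kGemm, kBatched };

    struct ASlice {
      int64_t pointer_offset; // elements to add to the base A pointer
      int k_begin;            // first K index this threadblock covers
      int k_end;              // one past the last K index
    };

    // Mirrors the per-threadblock logic above for operand A.
    ASlice locate_a(DualGemmMode mode, int k_idx, int grid_k,
                    int problem_k, int gemm_k_size, int64_t batch_stride_A) {
      if (mode == DualGemmMode::kGemm) {
        // Split-K serial: k_idx picks a K window; all slices share one A.
        int begin = k_idx * gemm_k_size;
        int end = (k_idx + 1 < grid_k) ? (k_idx + 1) * gemm_k_size : problem_k;
        return {0, begin, end};
      }
      // Batched: k_idx picks a batch; each batch covers the full K extent.
      return {k_idx * batch_stride_A, 0, problem_k};
    }

The epilogue applies the same batch offset to the C/D pointers using batch_stride_C and batch_stride_D.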
diff --git a/examples/45_dual_gemm/threadblock/dual_mma_base.h b/examples/45_dual_gemm/threadblock/dual_mma_base.h
index 10563e704..7031781ad 100644
--- a/examples/45_dual_gemm/threadblock/dual_mma_base.h
+++ b/examples/45_dual_gemm/threadblock/dual_mma_base.h
@@ -58,7 +58,9 @@ template <
  /// Size of the Gemm problem - concept: gemm::GemmShape<>
  typename Shape_,
  /// Policy describing tuning details (concept: MmaPolicy)
-  typename Policy_,
+  typename Policy0_,
+  /// B1-specific version of the policy (concept: MmaPolicy)
+  typename Policy1_,
  /// Number of stages,
  int Stages,
  /// Used for partial specialization
@@ -69,18 +71,20 @@ class DualMmaBase {
  using Shape = Shape_;

  ///< Policy describing tuning details
-  using Policy = Policy_;
+  using Policy0 = Policy0_;
+  using Policy1 = Policy1_;

  //
  // Dependent types
  //

  /// Warp-level Mma
-  using Operator = typename Policy::Operator;
+  using Operator0 = typename Policy0::Operator;
+  using Operator1 = typename Policy1::Operator;

  /// Shape describing the overall GEMM computed from shared memory
  /// by each warp.
-  using WarpGemm = typename Policy::Operator::Shape;
+  using WarpGemm = typename Policy0::Operator::Shape;

  /// Shape describing the number of warps filling the CTA
  using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;
+  using TensorRefA = TensorRef<typename Operator0::ElementA, typename Operator0::LayoutA>;

  /// Tensor reference to the B operand
-  using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
+  using TensorRefB0 = TensorRef<typename Operator0::ElementB, typename Operator0::LayoutB>;
+  using TensorRefB1 = TensorRef<typename Operator1::ElementB, typename Operator1::LayoutB>;

  static_assert(kWarpGemmIterations > 1,
                "The pipelined structure requires at least two warp-level "
@@ -119,14 +124,17 @@ class DualMmaBase {
    //

    /// Shape of the A matrix operand in shared memory
-    using ShapeA = MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
-                               Shape::kK * kStages +
-                               Policy::SmemPaddingA::kColumn>;
+    using ShapeA = MatrixShape<Shape::kM + Policy0::SmemPaddingA::kRow,
+                               Shape::kK * kStages +
+                               Policy0::SmemPaddingA::kColumn>;

    /// Shape of the B matrix operand in shared memory
-    using ShapeB =
-        MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
-                    Shape::kN + Policy::SmemPaddingB::kColumn>;
+    using ShapeB0 =
+        MatrixShape<Shape::kK * kStages + Policy0::SmemPaddingB::kRow,
+                    Shape::kN + Policy0::SmemPaddingB::kColumn>;
+    using ShapeB1 =
+        MatrixShape<Shape::kK * kStages + Policy1::SmemPaddingB::kRow,
+                    Shape::kN + Policy1::SmemPaddingB::kColumn>;

   public:
    //
@@ -134,11 +142,11 @@ class DualMmaBase {
    //

    /// Buffer for A operand
-    AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
+    AlignedBuffer<typename Operator0::ElementA, ShapeA::kCount> operand_A;

    /// Buffer for B operand
-    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B0;
-    AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B1;
+    AlignedBuffer<typename Operator0::ElementB, ShapeB0::kCount> operand_B0;
+    AlignedBuffer<typename Operator1::ElementB, ShapeB1::kCount> operand_B1;

   public:
@@ -148,14 +156,20 @@ class DualMmaBase {

    /// Returns a layout object for the A matrix
    CUTLASS_DEVICE
-    static typename Operator::LayoutA LayoutA() {
-      return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
+    static typename Operator0::LayoutA LayoutA() {
+      return Operator0::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
    }

    /// Returns a layout object for the B matrix
    CUTLASS_HOST_DEVICE
-    static typename Operator::LayoutB LayoutB() {
-      return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
+    static typename Operator0::LayoutB LayoutB0() {
+      return Operator0::LayoutB::packed({ShapeB0::kRow, ShapeB0::kColumn});
+    }
+
+    /// Returns a layout object for the B matrix
+    CUTLASS_HOST_DEVICE
+    static typename Operator1::LayoutB LayoutB1() {
+      return Operator1::LayoutB::packed({ShapeB1::kRow, ShapeB1::kColumn});
    }

    /// Returns a TensorRef to the A operand
@@ -166,12 +180,12 @@ class DualMmaBase {

    /// Returns a TensorRef to the B operand
    CUTLASS_HOST_DEVICE
-    TensorRefB operand_B0_ref() {
-      return TensorRefB{operand_B0.data(), LayoutB()};
+    TensorRefB0 operand_B0_ref() {
+      return TensorRefB0{operand_B0.data(), LayoutB0()};
    }

    CUTLASS_HOST_DEVICE
-    TensorRefB operand_B1_ref() {
-      return TensorRefB{operand_B1.data(), LayoutB()};
+    TensorRefB1 operand_B1_ref() {
+      return TensorRefB1{operand_B1.data(), LayoutB1()};
    }
  };
@@ -182,11 +196,11 @@ class DualMmaBase {
  //

  /// Iterator to load a warp-scoped tile of A operand from shared memory
-  typename Operator::IteratorA warp_tile_iterator_A_;
+  typename Operator0::IteratorA warp_tile_iterator_A_;

  /// Iterator to load a warp-scoped tile of B operand from shared memory
-  typename Operator::IteratorB warp_tile_iterator_B0_;
-  typename Operator::IteratorB warp_tile_iterator_B1_;
+  typename Operator0::IteratorB warp_tile_iterator_B0_;
+  typename Operator1::IteratorB warp_tile_iterator_B1_;

 public:
diff --git a/examples/45_dual_gemm/threadblock/dual_mma_multistage.h b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h
index 7843f2b2b..c486c69b9 100644
--- a/examples/45_dual_gemm/threadblock/dual_mma_multistage.h
+++ b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h
@@ -67,21 +67,30 @@ template <
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
-    /// Iterates over tiles of B operand in global memory
+    /// Iterates over tiles of B0 operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
-    typename IteratorB_,
-    /// Iterates over tiles of B operand in shared memory
+    typename IteratorB0_,
+    /// Iterates over tiles of B0 operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
-    typename SmemIteratorB_,
+    typename SmemIteratorB0_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
+    /// Iterates over tiles of B1 operand in global memory
+    //  (concept: ReadableTileIterator | ForwardTileIterator |
+    //  MaskedTileIterator)
+    typename IteratorB1_,
+    /// Iterates over tiles of B1 operand in shared memory
+    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+    typename SmemIteratorB1_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Data type of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
-    typename Policy_,
+    typename Policy0_,
+    /// B1-specific version of the policy (concept: MmaPolicy)
+    typename Policy1_,
    /// Number of stages,
    int Stages,
    /// Use zfill or predicate for out-of-bound cp.async
@@ -89,25 +98,29 @@ template <
    /// Used for partial specialization
    typename Enable = bool>
 class DualMmaMultistage :
-  public DualMmaBase<Shape_, Policy_, Stages> {
+  public DualMmaBase<Shape_, Policy0_, Policy1_, Stages> {
 public:
  ///< Base class
-  using Base = DualMmaBase<Shape_, Policy_, Stages>;
+  using Base = DualMmaBase<Shape_, Policy0_, Policy1_, Stages>;
  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  using Shape = Shape_;
  ///< Iterates over tiles of A operand in global memory
  using IteratorA = IteratorA_;
-  ///< Iterates over tiles of B operand in global memory
-  using IteratorB = IteratorB_;
+  ///< Iterates over tiles of B0 operand in global memory
+  using IteratorB0 = IteratorB0_;
+  ///< Iterates over tiles of B1 operand in global memory
+  using IteratorB1 = IteratorB1_;
  ///< Data type of accumulator matrix
  using ElementC = ElementC_;
  ///< Layout of accumulator matrix
  using LayoutC = LayoutC_;
  ///< Policy describing tuning details
-  using Policy = Policy_;
+  using Policy0 = Policy0_;
+  using Policy1 = Policy1_;

  using SmemIteratorA = SmemIteratorA_;
-  using SmemIteratorB = SmemIteratorB_;
+  using SmemIteratorB0 = SmemIteratorB0_;
+  using SmemIteratorB1 = SmemIteratorB1_;

  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
@@ -117,19 +130,21 @@ public:
  //

  /// Fragment of accumulator tile
-  using FragmentC = typename Policy::Operator::FragmentC;
+  using FragmentC = typename Policy0::Operator::FragmentC;

  /// Warp-level Mma
-  using Operator = typename Policy::Operator;
+  using Operator0 = typename Policy0::Operator;
+  using Operator1 = typename Policy1::Operator;

  /// Minimum architecture is Sm80 to support cp.async
  using ArchTag = arch::Sm80;

  /// Complex transform on A operand
-  static ComplexTransform const kTransformA = Operator::kTransformA;
+  static ComplexTransform const kTransformA = Operator0::kTransformA;

  /// Complex transform on B operand
-  static ComplexTransform const kTransformB = Operator::kTransformB;
+  static ComplexTransform const kTransformB0 = Operator0::kTransformB;
+  static ComplexTransform const kTransformB1 = Operator1::kTransformB;

  /// Internal structure exposed for introspection.
  struct Detail {
@@ -140,7 +155,7 @@ public:

    /// Number of cp.async instructions to load one stage of operand B
    static int const AsyncCopyIterationsPerStageB =
-        IteratorB::ThreadMap::Iterations::kCount;
+        IteratorB0::ThreadMap::Iterations::kCount;

    /// Number of stages
    static int const kStages = Stages;
@@ -156,10 +171,12 @@ public:

 private:

-  using WarpLoadedFragmentA = typename Operator::FragmentA;
-  using WarpLoadedFragmentB = typename Operator::FragmentB;
-  using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
-  using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
+  using WarpLoadedFragmentA = typename Operator0::FragmentA;
+  using WarpLoadedFragmentB0 = typename Operator0::FragmentB;
+  using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
+  using WarpTransformedFragmentA = typename Operator0::TransformedFragmentA;
+  using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB;
+  using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;

 private:
@@ -171,8 +188,8 @@ public:
  SmemIteratorA smem_iterator_A_;

  /// Iterator to write threadblock-scoped tile of B operand to shared memory
-  SmemIteratorB smem_iterator_B0_;
-  SmemIteratorB smem_iterator_B1_;
+  SmemIteratorB0 smem_iterator_B0_;
+  SmemIteratorB1 smem_iterator_B1_;

 public:
@@ -215,7 +232,7 @@ public:
  }

  CUTLASS_DEVICE
-  void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1,
+  void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB0 &iterator_B0, IteratorB1 &iterator_B1,
                              int group_start_A = 0, int group_start_B = 0) {
    iterator_A.set_iteration_index(group_start_A *
                                   IteratorA::kAccessesPerVector);
@@ -253,9 +270,9 @@ public:
    }

    iterator_B0.set_iteration_index(group_start_B *
-                                    IteratorB::kAccessesPerVector);
+                                    IteratorB0::kAccessesPerVector);
    iterator_B1.set_iteration_index(group_start_B *
-                                    IteratorB::kAccessesPerVector);
+                                    IteratorB1::kAccessesPerVector);
    this->smem_iterator_B0_.set_iteration_index(group_start_B);
    this->smem_iterator_B1_.set_iteration_index(group_start_B);
@@ -263,16 +280,16 @@ public:
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
      if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
-        typename IteratorB::AccessType *dst_ptr =
-            reinterpret_cast<typename IteratorB::AccessType *>(
+        typename IteratorB0::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorB0::AccessType *>(
                this->smem_iterator_B0_.get());

-        int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
-                              IteratorB::ThreadMap::kElementsPerAccess /
-                              IteratorB::kAccessesPerVector / 8;
+        int const kSrcBytes = sizeof_bits<typename IteratorB0::Element>::value *
+                              IteratorB0::ThreadMap::kElementsPerAccess /
+                              IteratorB0::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
-        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+        for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B0.get();

          if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
@@ -292,16 +309,16 @@ public:
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
      if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
-        typename IteratorB::AccessType *dst_ptr =
-            reinterpret_cast<typename IteratorB::AccessType *>(
+        typename IteratorB1::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorB1::AccessType *>(
                this->smem_iterator_B1_.get());

-        int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
-                              IteratorB::ThreadMap::kElementsPerAccess /
-                              IteratorB::kAccessesPerVector / 8;
+        int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
+                              IteratorB1::ThreadMap::kElementsPerAccess /
+                              IteratorB1::kAccessesPerVector / 8;

        CUTLASS_PRAGMA_UNROLL
-        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+        for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B1.get();

          if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
@@ -330,8 +347,8 @@ public:
      ///< iterator over A operand in global memory
      IteratorA iterator_A,
      ///< iterator over B operand in global memory
-      IteratorB iterator_B0,
-      IteratorB iterator_B1,
+      IteratorB0 iterator_B0,
+      IteratorB1 iterator_B1,
      ///< initial value of accumulator
      FragmentC const &src_accum0,
      FragmentC const &src_accum1
@@ -386,16 +403,16 @@ public:
      // Async Copy for operand B0
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
-        typename IteratorB::AccessType *dst_ptr =
-            reinterpret_cast<typename IteratorB::AccessType *>(
+        typename IteratorB0::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorB0::AccessType *>(
                this->smem_iterator_B0_.get());

        CUTLASS_PRAGMA_UNROLL
-        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+        for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) {
          int const kSrcBytes =
-              sizeof_bits<typename IteratorB::Element>::value *
-              IteratorB::ThreadMap::kElementsPerAccess /
-              IteratorB::kAccessesPerVector / 8;
+              sizeof_bits<typename IteratorB0::Element>::value *
+              IteratorB0::ThreadMap::kElementsPerAccess /
+              IteratorB0::kAccessesPerVector / 8;

          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
              dst_ptr + v, iterator_B0.get(), iterator_B0.valid());
@@ -408,16 +425,16 @@ public:
      // Async Copy for operand B1
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
-        typename IteratorB::AccessType *dst_ptr =
-            reinterpret_cast<typename IteratorB::AccessType *>(
+        typename IteratorB1::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorB1::AccessType *>(
                this->smem_iterator_B1_.get());

        CUTLASS_PRAGMA_UNROLL
-        for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
+        for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
          int const kSrcBytes =
-              sizeof_bits<typename IteratorB::Element>::value *
-              IteratorB::ThreadMap::kElementsPerAccess /
-              IteratorB::kAccessesPerVector / 8;
+              sizeof_bits<typename IteratorB1::Element>::value *
+              IteratorB1::ThreadMap::kElementsPerAccess /
+              IteratorB1::kAccessesPerVector / 8;

          cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
              dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
@@ -473,35 +490,35 @@ public:
        ++last_smem_iterator_A;
      }

-      typename IteratorB::AccessType zero_B;
+      typename IteratorB0::AccessType zero_B;
      zero_B.clear();

      /// Iterator to write threadblock-scoped tile of B0 operand to shared memory
-      SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_);
+      SmemIteratorB0 last_smem_iterator_B0(this->smem_iterator_B0_);
      last_smem_iterator_B0.set_iteration_index(0);

-      // Async Copy for operand B
+      // Async Copy for operand B0
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {

-        typename IteratorB::AccessType *dst_ptr =
-            reinterpret_cast<typename IteratorB::AccessType *>(
+        typename IteratorB0::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorB0::AccessType *>(
                last_smem_iterator_B0.get());

        *dst_ptr = zero_B;

        ++last_smem_iterator_B0;
      }
+
      /// Iterator to write threadblock-scoped tile of B1 operand to shared memory
-      SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_);
+      SmemIteratorB1 last_smem_iterator_B1(this->smem_iterator_B1_);
      last_smem_iterator_B1.set_iteration_index(0);

-      // Async Copy for operand B
+      // Async Copy for operand B1
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {

-        typename IteratorB::AccessType *dst_ptr =
-            reinterpret_cast<typename IteratorB::AccessType *>(
+        typename IteratorB1::AccessType *dst_ptr =
+            reinterpret_cast<typename IteratorB1::AccessType *>(
                last_smem_iterator_B1.get());

        *dst_ptr = zero_B;
@@ -517,13 +534,14 @@ public:
    // Pair of fragments used to overlap shared memory loads and math
    // instructions
    WarpLoadedFragmentA warp_loaded_frag_A[2];
-    WarpLoadedFragmentB warp_loaded_frag_B0[2];
-    WarpLoadedFragmentB warp_loaded_frag_B1[2];
+    WarpLoadedFragmentB0 warp_loaded_frag_B0[2];
+    WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
    WarpTransformedFragmentA warp_transformed_frag_A[2];
-    WarpTransformedFragmentB warp_transformed_frag_B0[2];
-    WarpTransformedFragmentB warp_transformed_frag_B1[2];
+    WarpTransformedFragmentB0 warp_transformed_frag_B0[2];
+    WarpTransformedFragmentB1 warp_transformed_frag_B1[2];

-    Operator warp_mma;
+    Operator0 warp_mma0;
+    Operator1 warp_mma1;

    this->warp_tile_iterator_A_.set_kgroup_index(0);
    this->warp_tile_iterator_B0_.set_kgroup_index(0);
@@ -544,10 +562,10 @@ public:
    int smem_write_stage_idx = Base::kStages - 1;
    int smem_read_stage_idx = 0;

-    warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0],
-                       warp_loaded_frag_A[0], warp_loaded_frag_B0[0]);
-    warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0],
-                       warp_loaded_frag_A[0], warp_loaded_frag_B1[0]);
+    warp_mma0.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0],
+                        warp_loaded_frag_A[0], warp_loaded_frag_B0[0]);
+    warp_mma1.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0],
+                        warp_loaded_frag_A[0], warp_loaded_frag_B1[0]);

    // tf32x3 kernels use staging accumulation. warp_mma uses a temporary
    // accumulator and this temporary accumulator is added to the final
@@ -556,9 +574,9 @@ public:

    FragmentC tmp_accum0, tmp_accum1;

-    if (platform::is_same<typename Operator::MathOperator,
-                          arch::OpMultiplyAddFastF32>::value
-      || platform::is_same<typename Operator::MathOperator,
-                           arch::OpMultiplyAddComplexFastF32>::value) {
+    if (platform::is_same<typename Operator0::MathOperator,
+                          arch::OpMultiplyAddFastF32>::value
+      || platform::is_same<typename Operator0::MathOperator,
+                           arch::OpMultiplyAddComplexFastF32>::value) {

      tmp_accum0.clear();
@@ -597,28 +615,28 @@ public:
        ++this->warp_tile_iterator_B1_;

        if (warp_mma_k > 0) {
-          warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
-                             warp_transformed_frag_B0[warp_mma_k % 2],
-                             warp_loaded_frag_A[warp_mma_k % 2],
-                             warp_loaded_frag_B0[warp_mma_k % 2]);
-          warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
-                             warp_transformed_frag_B1[warp_mma_k % 2],
-                             warp_loaded_frag_A[warp_mma_k % 2],
-                             warp_loaded_frag_B1[warp_mma_k % 2]);
+          warp_mma0.transform(warp_transformed_frag_A[warp_mma_k % 2],
+                              warp_transformed_frag_B0[warp_mma_k % 2],
+                              warp_loaded_frag_A[warp_mma_k % 2],
+                              warp_loaded_frag_B0[warp_mma_k % 2]);
+          warp_mma1.transform(warp_transformed_frag_A[warp_mma_k % 2],
+                              warp_transformed_frag_B1[warp_mma_k % 2],
+                              warp_loaded_frag_A[warp_mma_k % 2],
+                              warp_loaded_frag_B1[warp_mma_k % 2]);
        }

-        if (platform::is_same<typename Operator::MathOperator,
-                              arch::OpMultiplyAddFastF32>::value
-          || platform::is_same<typename Operator::MathOperator,
-                               arch::OpMultiplyAddComplexFastF32>::value) {
+        if (platform::is_same<typename Operator0::MathOperator,
+                              arch::OpMultiplyAddFastF32>::value
+          || platform::is_same<typename Operator0::MathOperator,
+                               arch::OpMultiplyAddComplexFastF32>::value) {

-          warp_mma(
+          warp_mma0(
            tmp_accum0,
            warp_transformed_frag_A[warp_mma_k % 2],
            warp_transformed_frag_B0[warp_mma_k % 2],
            tmp_accum0
          );
-          warp_mma(
+          warp_mma1(
            tmp_accum1,
            warp_transformed_frag_A[warp_mma_k % 2],
            warp_transformed_frag_B1[warp_mma_k % 2],
@@ -632,13 +650,13 @@ public:
            tmp_accum1.clear();
          }
        } else {
-          warp_mma(
+          warp_mma0(
            accum0,
            warp_transformed_frag_A[warp_mma_k % 2],
            warp_transformed_frag_B0[warp_mma_k % 2],
            accum0
          );
-          warp_mma(
+          warp_mma1(
            accum1,
            warp_transformed_frag_A[warp_mma_k % 2],
            warp_transformed_frag_B1[warp_mma_k % 2],
@@ -696,14 +714,14 @@ public:

          if (smem_read_stage_idx == (Base::kStages - 1)) {
            this->warp_tile_iterator_A_.add_tile_offset(
-                {0, -Base::kStages * Policy::kPartitionsK *
+                {0, -Base::kStages * Policy0::kPartitionsK *
                        Base::kWarpGemmIterations});
            this->warp_tile_iterator_B0_.add_tile_offset(
-                {-Base::kStages * Policy::kPartitionsK *
+                {-Base::kStages * Policy0::kPartitionsK *
                     Base::kWarpGemmIterations,
                 0});
            this->warp_tile_iterator_B1_.add_tile_offset(
-                {-Base::kStages * Policy::kPartitionsK *
+                {-Base::kStages * Policy1::kPartitionsK *
                     Base::kWarpGemmIterations,
                 0});
            smem_read_stage_idx = 0;
@@ -720,22 +738,22 @@ public:
        // Do any conversions feeding the first stage at the end of the loop so
        // we can start right away on mma instructions
        if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
-          warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
-                             warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
-                             warp_loaded_frag_A[(warp_mma_k + 1) % 2],
-                             warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
-          warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
-                             warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
-                             warp_loaded_frag_A[(warp_mma_k + 1) % 2],
-                             warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
+          warp_mma0.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
+                              warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
+                              warp_loaded_frag_A[(warp_mma_k + 1) % 2],
+                              warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
+          warp_mma1.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
+                              warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
+                              warp_loaded_frag_A[(warp_mma_k + 1) % 2],
+                              warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
        }
      }
    }

-    if (platform::is_same<typename Operator::MathOperator,
-                          arch::OpMultiplyAddFastF32>::value
-      || platform::is_same<typename Operator::MathOperator,
-                           arch::OpMultiplyAddComplexFastF32>::value) {
+    if (platform::is_same<typename Operator0::MathOperator,
+                          arch::OpMultiplyAddFastF32>::value
+      || platform::is_same<typename Operator0::MathOperator,
+                           arch::OpMultiplyAddComplexFastF32>::value) {
      accum0 = plus_accum(accum0, tmp_accum0);
      accum1 = plus_accum(accum1, tmp_accum1);
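Splitting Policy into Policy0/Policy1 is what lets B0 and B1 use different layouts end to end: each B operand gets its own warp-level operator, shared-memory shape, and iterator pair. A hedged sketch of a mixed-layout instantiation in the style of run_broadcast_fused_gemm_f16_sm80_shmem above (the ElementOperand*/ElementOutput/ElementAccumulator and EpilogueOutputOp* aliases are the example's own, assumed in scope):

    using DualGemmMixedB = cutlass::gemm::device::DualGemm<
        ElementOperandA, cutlass::layout::RowMajor,   // A
        ElementOperandB,
        cutlass::layout::RowMajor,                    // LayoutB0
        cutlass::layout::ColumnMajor,                 // LayoutB1 (broadcast-friendly)
        ElementOutput, cutlass::layout::RowMajor,     // C / D0 / D1 / D2
        ElementAccumulator,
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 64, 32>,        // threadblock tile
        cutlass::gemm::GemmShape<64, 32, 32>,         // warp tile
        cutlass::gemm::GemmShape<16, 8, 16>,          // instruction shape
        EpilogueOutputOp0, EpilogueOutputOp1, EpilogueOutputOp2,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
        /*Stages=*/3, /*StoreD0=*/true, /*StoreD1=*/true, /*SplitKSerial=*/false>;

Internally, DefaultMma is instantiated once per B operand (Mma0, Mma1), and DualMmaMultistage threads both iterator/policy pairs through to DualMmaBase, which sizes the two shared-memory B buffers independently.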