From 3c995c7606e14027dff8a79973ae91f31b5d8628 Mon Sep 17 00:00:00 2001
From: Adnan Akhundov
Date: Mon, 13 Feb 2023 21:27:13 +0100
Subject: [PATCH] Extend DualGemm: support batched mode + decouple B0/B1 layouts (#790)

* Fix MHA kernel

* Extend DualGemm to support batched mode (#5)

Following the GemmUniversalMode::kBatched implementation, batched mode is
added to the DualGemm (under examples/45_dual_gemm). DualGemmMode::kBatched
and SplitKSerial are not compatible: Status::kErrorInvalidProblem is
returned if both are set.

* Decouple LayoutB0 and LayoutB1 in DualGemm

The DualGemm template assumed the same layout, LayoutB, for both right-hand
operand matrices B0 and B1. This is problematic when the layouts of the two
matrices differ. In particular, one matrix may be row-major, while the other
is a (column) vector that must be broadcast in column-major with zero stride
(e.g., as {B1.device_data(), 0}) so that the DualGemm implementation can
process B0 and B1 simultaneously; a brief usage sketch follows the first
device-side hunk below.

In this commit, LayoutB0 and LayoutB1 are decoupled throughout the DualGemm
code (device, kernel, and mma). Additionally, the batch strides of B0 and B1
are decoupled to accommodate the column-vector B1 case described above.

* Remove comment as no longer relevant

* Revert "Fix MHA kernel"

---------

Co-authored-by: mikeiovine
---
 examples/45_dual_gemm/device/dual_gemm.h      | 110 +++++--
 examples/45_dual_gemm/dual_gemm.cu            | 210 +++++++++++-
 examples/45_dual_gemm/dual_gemm_common.h      |  52 +++
 examples/45_dual_gemm/dual_gemm_run.h         | 304 ++++++++++++------
 examples/45_dual_gemm/kernel/dual_gemm.h      | 140 +++++---
 .../45_dual_gemm/threadblock/dual_mma_base.h  |  66 ++--
 .../threadblock/dual_mma_multistage.h         | 216 +++++++------
 7 files changed, 793 insertions(+), 305 deletions(-)
 create mode 100644 examples/45_dual_gemm/dual_gemm_common.h

diff --git a/examples/45_dual_gemm/device/dual_gemm.h b/examples/45_dual_gemm/device/dual_gemm.h
index 491888bef..71f7973ed 100644
--- a/examples/45_dual_gemm/device/dual_gemm.h
+++ b/examples/45_dual_gemm/device/dual_gemm.h
@@ -52,6 +52,7 @@ D2 = element_wise(D0, D1)
 #include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
 
 #include "../kernel/dual_gemm.h"
+#include "../dual_gemm_common.h"
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -68,8 +69,10 @@ template <
   typename LayoutA_,
   /// Element type for B matrix operand
   typename ElementB_,
-  /// Layout type for B matrix operand
-  typename LayoutB_,
+  /// Layout type for B0 matrix operand
+  typename LayoutB0_,
+  /// Layout type for B1 matrix operand
+  typename LayoutB1_,
   /// Element type for C and D matrix operands
   typename ElementC_,
   /// Layout type for C and D matrix operands
@@ -119,8 +122,10 @@ class DualGemm {
   using LayoutA = LayoutA_;
   using TensorRefA = TensorRef<ElementA const, LayoutA>;
   using ElementB = ElementB_;
-  using LayoutB = LayoutB_;
-  using TensorRefB = TensorRef<ElementB const, LayoutB>;
+  using LayoutB0 = LayoutB0_;
+  using LayoutB1 = LayoutB1_;
+  using TensorRefB0 = TensorRef<ElementB const, LayoutB0>;
+  using TensorRefB1 = TensorRef<ElementB const, LayoutB1>;
   using ElementC = ElementC_;
   using LayoutC = LayoutC_;
   using TensorRefC = TensorRef<ElementC const, LayoutC>;
@@ -151,23 +156,31 @@ class DualGemm {
   /// Define the threadblock-scoped matrix multiply-accumulate
   static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented");
   static_assert(kStages >= 3, "Only multistage is implemented");
-  using Mma = typename cutlass::gemm::threadblock::DefaultMma<
-      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
+  using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
+      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB,
+      ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
+      ThreadblockShape, WarpShape,
+      InstructionShape, Stages, Operator>::ThreadblockMma;
+  using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
+      ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB,
       ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
       ThreadblockShape, WarpShape,
       InstructionShape, Stages, Operator>::ThreadblockMma;
   using DualMma = threadblock::DualMmaMultistage<
-    typename Mma::Shape,
-    typename Mma::IteratorA,
-    typename Mma::SmemIteratorA,
-    Mma::kCacheOpA,
-    typename Mma::IteratorB,
-    typename Mma::SmemIteratorB,
-    Mma::kCacheOpB,
-    typename Mma::ElementC,
-    typename Mma::LayoutC,
-    typename Mma::Policy,
-    Mma::kStages,
+    typename Mma0::Shape,
+    typename Mma0::IteratorA,
+    typename Mma0::SmemIteratorA,
+    Mma0::kCacheOpA,
+    typename Mma0::IteratorB,
+    typename Mma0::SmemIteratorB,
+    Mma0::kCacheOpB,
+    typename Mma1::IteratorB,
+    typename Mma1::SmemIteratorB,
+    typename Mma0::ElementC,
+    typename Mma0::LayoutC,
+    typename Mma0::Policy,
+    typename Mma1::Policy,
+    Mma0::kStages,
     SharedMemoryClearOption::kNone
   >;
 
@@ -176,11 +189,11 @@ class DualGemm {
 
   /// Define the epilogue
   using Epilogue0 = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
-    ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp0,
+    ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0,
     EpilogueOutputOp0::kCount>::Epilogue;
   using Epilogue1 = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
-    ThreadblockShape, typename DualMma::Operator, kPartitionsK, EpilogueOutputOp1,
+    ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1,
     EpilogueOutputOp1::kCount>::Epilogue;
 
   /// Define the kernel-level GEMM operator.
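For context, a minimal usage sketch of the extended Arguments (illustrative
only, not part of the patch: `MyDualGemm` stands for any DualGemm
instantiation with LayoutB0 = RowMajor and LayoutB1 = ColumnMajor, and the
tensors are assumed to be allocated elsewhere):

    // B1 is a column vector of length K, broadcast across all N columns by
    // giving its column-major TensorRef a zero stride; the kernel then reads
    // the same K elements for every column of B1.
    typename MyDualGemm::Arguments args{
      cutlass::gemm::DualGemmMode::kBatched,
      problem_size,
      tensor_A.device_ref(),
      tensor_B0.device_ref(),
      tensor_C0.device_ref(),
      tensor_D0.device_ref(),
      typename MyDualGemm::TensorRefB1(tensor_B1.device_data(), 0),  // zero stride
      tensor_C1.device_ref(),
      tensor_D1.device_ref(),
      tensor_D2.device_ref(),
      {alpha0, beta0},
      {alpha1, beta1},
      {},                                    // epilogue2 params
      1,                                     // split_k_slices (kBatched forbids split-K)
      batch_count,
      problem_size.m() * problem_size.k(),   // batch_stride_A
      problem_size.k() * problem_size.n(),   // batch_stride_B0
      problem_size.k(),                      // batch_stride_B1: one vector per batch
      problem_size.m() * problem_size.n(),   // batch_stride_C
      problem_size.m() * problem_size.n()};  // batch_stride_D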
@@ -197,12 +210,13 @@ class DualGemm {
     // Data members
     //
 
+    DualGemmMode mode;
     GemmCoord problem_size;
     TensorRef<ElementA const, LayoutA> ref_A0;
-    TensorRef<ElementB const, LayoutB> ref_B0;
+    TensorRef<ElementB const, LayoutB0> ref_B0;
     TensorRef<ElementC const, LayoutC> ref_C0;
     TensorRef<ElementC, LayoutC> ref_D0;
-    TensorRef<ElementB const, LayoutB> ref_B1;
+    TensorRef<ElementB const, LayoutB1> ref_B1;
     TensorRef<ElementC const, LayoutC> ref_C1;
     TensorRef<ElementC, LayoutC> ref_D1;
     TensorRef<ElementC, LayoutC> ref_D2;
@@ -211,6 +225,13 @@
     typename EpilogueOutputOp2::Params epilogue2;
     int split_k_slices;
 
+    int batch_count;
+    int64_t batch_stride_A;
+    int64_t batch_stride_B0;
+    int64_t batch_stride_B1;
+    int64_t batch_stride_C;
+    int64_t batch_stride_D;
+
     //
     // Methods
     //
@@ -224,12 +245,13 @@
     /// Constructs an Arguments structure
     CUTLASS_HOST_DEVICE
     Arguments(
+      DualGemmMode mode,
       GemmCoord problem_size_,
       TensorRef<ElementA const, LayoutA> ref_A0_,
-      TensorRef<ElementB const, LayoutB> ref_B0_,
+      TensorRef<ElementB const, LayoutB0> ref_B0_,
       TensorRef<ElementC const, LayoutC> ref_C0_,
       TensorRef<ElementC, LayoutC> ref_D0_,
-      TensorRef<ElementB const, LayoutB> ref_B1_,
+      TensorRef<ElementB const, LayoutB1> ref_B1_,
       TensorRef<ElementC const, LayoutC> ref_C1_,
       TensorRef<ElementC, LayoutC> ref_D1_,
       TensorRef<ElementC, LayoutC> ref_D2_,
@@ -239,8 +261,15 @@
         typename EpilogueOutputOp1::Params(),
       typename EpilogueOutputOp2::Params epilogue2_ =
         typename EpilogueOutputOp2::Params(),
-      int split_k_slices_ = 1
+      int split_k_slices_ = 1,
+      int batch_count = 1,
+      int64_t batch_stride_A = 0,
+      int64_t batch_stride_B0 = 0,
+      int64_t batch_stride_B1 = 0,
+      int64_t batch_stride_C = 0,
+      int64_t batch_stride_D = 0
     ):
+      mode(mode),
       problem_size(problem_size_),
       ref_A0(ref_A0_),
       ref_B0(ref_B0_),
@@ -253,7 +282,13 @@
       epilogue0(epilogue0_),
       epilogue1(epilogue1_),
       epilogue2(epilogue2_),
-      split_k_slices(split_k_slices_) {
+      split_k_slices(split_k_slices_),
+      batch_count(batch_count),
+      batch_stride_A(batch_stride_A),
+      batch_stride_B0(batch_stride_B0),
+      batch_stride_B1(batch_stride_B1),
+      batch_stride_C(batch_stride_C),
+      batch_stride_D(batch_stride_D) {
 
     }
   };
@@ -271,6 +306,9 @@ public:
   /// Determines whether the GEMM can execute the given problem.
   static Status can_implement(Arguments const &args) {
 
+    if (args.mode == DualGemmMode::kBatched && kSplitKSerial) {
+      return Status::kErrorInvalidProblem;
+    }
     if (!kSplitKSerial && args.split_k_slices > 1) {
       return Status::kErrorInvalidProblem;
     }
@@ -304,17 +342,15 @@ public:
   static size_t get_workspace_size(Arguments const &args) {
 
     size_t bytes = 0;
-
-    // Determine grid shape
-    ThreadblockSwizzle threadblock_swizzle;
-
-    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
-      args.problem_size,
-      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
-      args.split_k_slices);
 
     if (kSplitKSerial && args.split_k_slices > 1) {
+      // Determine grid shape
+      ThreadblockSwizzle threadblock_swizzle;
+      cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
+        args.problem_size,
+        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
+        args.split_k_slices);
 
       bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
     }
@@ -331,7 +367,7 @@ public:
     cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
       args.problem_size,
       {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
-      args.split_k_slices);
+      args.mode == DualGemmMode::kBatched ?
+        args.batch_count : args.split_k_slices);
 
     if (kSplitKSerial) {
       if (args.split_k_slices > 1) {
@@ -357,6 +393,7 @@ public:
 
     // Initialize the Params structure
     params_ = typename DualGemmKernel::Params{
+      args.mode,
       args.problem_size,
       grid_shape,
       args.ref_A0.non_const_ref(),
@@ -371,6 +408,11 @@ public:
       args.epilogue1,
       args.epilogue2,
       reinterpret_cast<int *>(workspace),
+      args.batch_stride_A,
+      args.batch_stride_B0,
+      args.batch_stride_B1,
+      args.batch_stride_C,
+      args.batch_stride_D,
     };
 
     return Status::kSuccess;
diff --git a/examples/45_dual_gemm/dual_gemm.cu b/examples/45_dual_gemm/dual_gemm.cu
index 15974e033..75ef15020 100644
--- a/examples/45_dual_gemm/dual_gemm.cu
+++ b/examples/45_dual_gemm/dual_gemm.cu
@@ -43,8 +43,6 @@ D2 = element_wise(D0, D1)
   D0 and D1 will be optionally stored in gmem (`kStoreD0` / `kStoreD1`)
 */
 
-// #define IS_PROFILING
-
 #include <iostream>
 
 #include "cutlass/cutlass.h"
@@ -66,10 +64,12 @@ D2 = element_wise(D0, D1)
 ////////////////////////////////////////////////////////////////////////////////
 
 cutlass::gemm::GemmCoord problem_size(4096, 4096, 8192);
+cutlass::gemm::GemmCoord batch_problem_size(321, 256, 512);
 
 constexpr int kStages = 3;
 constexpr bool kSplitKSerial = false;
 constexpr bool kUseBias = true;
+constexpr int kBatchCount = 37;
 
 #if 0
 
@@ -165,7 +165,16 @@ bool run_nonfused_gemm_f16_sm80() {
   NonFusedDualGemmRun<Gemm0, Gemm1> nonFusedGemm;
 
   std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n";
-  bool pass = nonFusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
+
+  bool pass = nonFusedGemm.run(
+    problem_size,
+    alpha0,
+    beta0,
+    alpha1,
+    beta1,
+    true /* is_profiling */
+  );
+
   if(pass)
     std::cout << "Pass\n";
   else
@@ -215,6 +224,7 @@ bool run_fused_gemm_f16_sm80_shmem() {
     cutlass::layout::RowMajor,
     ElementOperandB,
     cutlass::layout::ColumnMajor,
+    cutlass::layout::ColumnMajor,
     ElementOutput,
     cutlass::layout::RowMajor,
     ElementAccumulator,
@@ -236,24 +246,212 @@ bool run_fused_gemm_f16_sm80_shmem() {
   DualFusedGemmRun<DualGemm> fusedGemm;
 
   std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n";
-  bool passed = fusedGemm.run(problem_size, alpha0, beta0, alpha1, beta1);
+
+  bool passed = fusedGemm.run(
+    problem_size,
+    alpha0,
+    beta0,
+    alpha1,
+    beta1
+  );
+
   if(passed)
     std::cout << "Pass\n";
   else
     std::cout << "Fail\n";
 
   return passed;
+}
+
+bool run_batched_fused_gemm_f16_sm80_shmem() {
+  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+
+  // Optionally, we might not need intermediate GEMM outputs
+  constexpr bool kStoreD0 = true;
+  constexpr bool kStoreD1 = true;
+
+  using DualGemm = cutlass::gemm::device::DualGemm<
+    ElementOperandA,
+    cutlass::layout::RowMajor,
+    ElementOperandB,
+    cutlass::layout::ColumnMajor,
+    cutlass::layout::ColumnMajor,
+    ElementOutput,
+    cutlass::layout::RowMajor,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    ThreadblockShape,
+    WarpShape,
+    InstructionShape,
+    EpilogueOutputOp0,
+    EpilogueOutputOp1,
+    EpilogueOutputOp2,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
+    kStages,
+    kStoreD0,
+    kStoreD1,
+    kSplitKSerial
+  >;
+
+  DualFusedGemmRun<DualGemm> fusedGemm;
+
+  std::cout << "Running Batched Fused FP16 TN GEMMs + Epilogue2...\n";
+
+  bool passed = fusedGemm.run(
+    batch_problem_size,
+    alpha0,
+    beta0,
+    alpha1,
+    beta1,
+    kBatchCount,
+    false, /* broadcast_b1 */
+    false /* is_profiling */
+  );
+
+  if(passed)
+    std::cout << "Pass\n";
+  else
+    std::cout << "Fail\n";
+
+  return passed;
passed; +} + +bool run_broadcast_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + // different LayoutB0 and B1 + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Broadcast Fused FP16 TN GEMMs + Epilogue2...\n"; + + bool passed = fusedGemm.run( + problem_size, + alpha0, + beta0, + alpha1, + beta1, + 1, /* batch_count */ + true, /* broadcast_b1 */ + true /* is_profiling */ + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} + +bool run_batched_broadcast_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + // different LayoutB0 and B1 + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Batch Broadcast Fused FP16 TN GEMMs + Epilogue2...\n"; + + bool passed = fusedGemm.run( + batch_problem_size, + alpha0, + beta0, + alpha1, + beta1, + kBatchCount, + true, /* broadcast_b1 */ + false /* is_profiling */ + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; } int main() { std::vectorfuncs = { &run_nonfused_gemm_f16_sm80, - &run_fused_gemm_f16_sm80_shmem + &run_fused_gemm_f16_sm80_shmem, + &run_batched_fused_gemm_f16_sm80_shmem, + &run_broadcast_fused_gemm_f16_sm80_shmem, + &run_batched_broadcast_fused_gemm_f16_sm80_shmem }; - std::string test_name = "dual-gemm f16 bias=" + std::to_string(kUseBias) + " split_k_serial=" + std::to_string(kSplitKSerial); + std::string test_name = ( + "dual-gemm f16 bias=" + + std::to_string(kUseBias) + + " split_k_serial=" + + std::to_string(kSplitKSerial) + + " batch_count=" + + std::to_string(kBatchCount) + ); + return testRun(80, funcs, test_name); } diff --git a/examples/45_dual_gemm/dual_gemm_common.h b/examples/45_dual_gemm/dual_gemm_common.h new file mode 100644 index 000000000..615b480ef --- /dev/null +++ b/examples/45_dual_gemm/dual_gemm_common.h @@ -0,0 +1,52 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines common types used for all DualGemm operators. 
+*/
+
+#pragma once
+
+namespace cutlass {
+namespace gemm {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+enum class DualGemmMode {
+  kGemm,
+  kBatched,
+  kInvalid
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/45_dual_gemm/dual_gemm_run.h b/examples/45_dual_gemm/dual_gemm_run.h
index 63ca2ac84..d6a52d58e 100644
--- a/examples/45_dual_gemm/dual_gemm_run.h
+++ b/examples/45_dual_gemm/dual_gemm_run.h
@@ -33,6 +33,7 @@
 #include <iostream>
 #include <fstream>
 #include <sstream>
+#include <type_traits>
 
 #include "cutlass/util/host_tensor.h"
 #include "cutlass/util/tensor_view_io.h"
@@ -44,6 +45,10 @@
 #include "cutlass/util/reference/device/gemm.h"
 #include "cutlass/util/reference/device/tensor_relu.h"
 
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+
+#include "dual_gemm_common.h"
 #include "helper.h"
 
 #define CHECK_GT(val1, val2) \
@@ -205,6 +210,7 @@ struct NonFusedDualGemmRun
     ElementCompute beta0 = ElementCompute(0),
     ElementCompute alpha1 = ElementCompute(1),
     ElementCompute beta1 = ElementCompute(0),
+    bool is_profiling = true,
     bool relu = false,
     int warm_ups = 1,
     int runs = 100) {
@@ -335,40 +341,41 @@ struct NonFusedDualGemmRun
       status = gemm_op_1();
       CUTLASS_CHECK(status);
     }
-#ifdef IS_PROFILING
-    return true;
-#endif
-    //
-    // Run the GEMM
-    //
-    cudaEvent_t start, stop1, stop2;
-    cudaEventCreate(&start);
-    cudaEventCreate(&stop1);
-    cudaEventCreate(&stop2);
-    cudaEventRecord(start);
+    if (is_profiling) {
+      //
+      // Profile the GEMM
+      //
 
-    for(int i = 0; i < runs; i++) {
-      status = gemm_op_0();
-
-      CUTLASS_CHECK(status);
+      cudaEvent_t start, stop1, stop2;
+      cudaEventCreate(&start);
+      cudaEventCreate(&stop1);
+      cudaEventCreate(&stop2);
+
+      cudaEventRecord(start);
+
+      for(int i = 0; i < runs; i++) {
+        status = gemm_op_0();
+
+        CUTLASS_CHECK(status);
+      }
+      cudaEventRecord(stop1);
+      for(int i = 0; i < runs; i++) {
+        status = gemm_op_1();
+
+        CUTLASS_CHECK(status);
+      }
+
+      cudaEventRecord(stop2);
+      cudaDeviceSynchronize();
+      float gemm0Time, gemm1Time, totalTime;
+      cudaEventElapsedTime(&gemm0Time, start, stop1);
+      cudaEventElapsedTime(&gemm1Time, stop1, stop2);
+      cudaEventElapsedTime(&totalTime, start, stop2);
+      std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
+      std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
+      std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
     }
-    cudaEventRecord(stop1);
-    for(int i = 0; i < runs; i++) {
-      status = gemm_op_1();
-
-      CUTLASS_CHECK(status);
-    }
-
-    cudaEventRecord(stop2);
-    cudaDeviceSynchronize();
-    float gemm0Time, gemm1Time, totalTime;
-    cudaEventElapsedTime(&gemm0Time, start, stop1);
-    cudaEventElapsedTime(&gemm1Time, stop1, stop2);
-    cudaEventElapsedTime(&totalTime, start, stop2);
-    std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
-    std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
-    std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
 
     tensor_D0.sync_host();
     tensor_D1.sync_host();
@@ -543,6 +550,9 @@ struct DualFusedGemmRun
     ElementCompute beta0 = ElementCompute(1),
     ElementCompute alpha1 = ElementCompute(1),
     ElementCompute beta1 = ElementCompute(1),
+    int batch_count = 1,
+    bool broadcast_b1 = false,
+    bool is_profiling = true,
     bool relu = false,
     int warm_ups = 1,
     int runs = 100) {
@@ -553,55 +563,91 @@ struct DualFusedGemmRun
 
     cutlass::HostTensor<
       typename DualGemm::ElementA,
-      typename DualGemm::LayoutA> tensor_A0(problem_size.mk());
+      typename DualGemm::LayoutA> tensor_A0(
+        std::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementB,
-      typename DualGemm::LayoutB> tensor_B0(problem_size.kn());
+      typename DualGemm::LayoutB0> tensor_B0(
+        std::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_C0(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_C0(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutScaleBias> tensor_Bias0({1, problem_size.n()});
+      typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()});
 
     cutlass::HostTensor<
      typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_D0(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_D0(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> reference_D0(problem_size.mn());
+      typename DualGemm::LayoutC> reference_D0(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementB,
-      typename DualGemm::LayoutB> tensor_B1(problem_size.kn());
+      typename DualGemm::LayoutB1> tensor_B1(
+        std::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
+    if (broadcast_b1) {
+      tensor_B1.resize({problem_size.k(), batch_count});
+    }
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_C1(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_C1(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutScaleBias> tensor_Bias1({1, problem_size.n()});
+      typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()});
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_D1(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_D1(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> tensor_D2(problem_size.mn());
+      typename DualGemm::LayoutC> tensor_D2(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> reference_D1(problem_size.mn());
+      typename DualGemm::LayoutC> reference_D1(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
 
     cutlass::HostTensor<
       typename DualGemm::ElementC,
-      typename DualGemm::LayoutC> reference_D2(problem_size.mn());
+      typename DualGemm::LayoutC> reference_D2(
+        std::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
+          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
+          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));
 
     CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
     CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
@@ -638,6 +684,20 @@ struct DualFusedGemmRun
     reference_D1.sync_device();
     reference_D2.sync_device();
 
+    //
+    // Batch strides (irrelevant when batch_count == 1)
+    //
+
+    int64_t batch_stride_A = problem_size.m() * problem_size.k();
+    int64_t batch_stride_B0 = problem_size.k() * problem_size.n();
+    int64_t batch_stride_B1 = problem_size.k() * problem_size.n();
+    if (broadcast_b1) {
+      // B1 is a (column) vector
+      batch_stride_B1 = problem_size.k();
+    }
+    int64_t batch_stride_Bias = problem_size.n();
+    int64_t batch_stride_D = problem_size.m() * problem_size.n();
+
     //
     // Initialize the GEMM operator
     //
@@ -652,21 +712,36 @@ struct DualFusedGemmRun
       ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
     }
     typename DualGemm::Arguments arguments{
+      (batch_count > 1 ?
+        cutlass::gemm::DualGemmMode::kBatched :
+        cutlass::gemm::DualGemmMode::kGemm),
       problem_size,
       tensor_A0.device_ref(),
       tensor_B0.device_ref(),
       ref_B0,
       DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
-      tensor_B1.device_ref(),
+      (broadcast_b1 ?
+        typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
+        tensor_B1.device_ref()),
       ref_B1,
      DualGemm::kStoreD1 ?
        tensor_D1.device_ref() : nullptr_ref,
       tensor_D2.device_ref(),
       {alpha0, beta0},
       {alpha1, beta1},
       {},
-      split_k_slices
+      split_k_slices,
+      batch_count,
+      batch_stride_A,
+      batch_stride_B0,
+      batch_stride_B1,
+      batch_stride_Bias,
+      batch_stride_D,
     };
 
+    //
+    // Run the GEMM
+    //
+
     DualGemm b2b_gemm_op;
 
     cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));
@@ -684,31 +759,29 @@ struct DualFusedGemmRun
       CUTLASS_CHECK(status);
     }
 
-#ifdef IS_PROFILING
-    return true;
-#endif
-    //
-    // Run the GEMM
-    //
+    if (is_profiling) {
+      //
+      // Profile the GEMM
+      //
 
-    cudaEvent_t start, stop;
-    cudaEventCreate(&start);
-    cudaEventCreate(&stop);
+      cudaEvent_t start, stop;
+      cudaEventCreate(&start);
+      cudaEventCreate(&stop);
 
-    cudaEventRecord(start);
+      cudaEventRecord(start);
 
-    for(int i = 0; i < runs; i++) {
-      status = b2b_gemm_op();
+      for(int i = 0; i < runs; i++) {
+        status = b2b_gemm_op();
+        CUTLASS_CHECK(status);
+      }
 
-      CUTLASS_CHECK(status);
+      cudaEventRecord(stop);
+      cudaDeviceSynchronize();
+      float gemmTime;
+      cudaEventElapsedTime(&gemmTime, start, stop);
+      std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
     }
 
-    cudaEventRecord(stop);
-    cudaDeviceSynchronize();
-    float gemmTime;
-    cudaEventElapsedTime(&gemmTime, start, stop);
-    std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
-
     tensor_D0.sync_host();
     tensor_D1.sync_host();
     tensor_D2.sync_host();
@@ -717,45 +790,81 @@ struct DualFusedGemmRun
     // Verify
     //
 
-    cutlass::reference::device::Gemm<
-      typename DualGemm::ElementA, typename DualGemm::LayoutA,
-      typename DualGemm::ElementB, typename DualGemm::LayoutB,
-      typename DualGemm::ElementC, typename DualGemm::LayoutC,
-      ElementAccumulator, ElementAccumulator>
-      reference_gemm_0;
+    using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
+      typename DualGemm::ElementA, typename DualGemm::LayoutA,
+      typename DualGemm::ElementB, typename DualGemm::LayoutB0,
+      typename DualGemm::ElementC, typename DualGemm::LayoutC,
+      ElementAccumulator
+    >;
 
-    cutlass::reference::device::Gemm<
-      typename DualGemm::ElementA, typename DualGemm::LayoutA,
-      typename DualGemm::ElementB, typename DualGemm::LayoutB,
-      typename DualGemm::ElementC, typename DualGemm::LayoutC, ElementCompute,
-      ElementAccumulator, typename DualGemm::Operator>
-      reference_gemm_1;
+    GemmUniversal0 reference_gemm0;
 
-    reference_gemm_0(
+    typename GemmUniversal0::Arguments args0 {
+      (batch_count > 1 ?
+ cutlass::gemm::GemmUniversalMode::kBatched : + cutlass::gemm::GemmUniversalMode::kGemm), problem_size, - alpha0, - tensor_A0.device_ref(), - tensor_B0.device_ref(), - beta0, - {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}, - reference_D0.device_ref() - ); + batch_count, + {alpha0, beta0}, + tensor_A0.device_data(), + tensor_B0.device_data(), + tensor_Bias0.device_data(), + reference_D0.device_data(), + batch_stride_A, + batch_stride_B0, + batch_stride_Bias, + batch_stride_D, + tensor_A0.stride(0), + tensor_B0.stride(0), + 0, // zero stride for the bias vector + reference_D0.stride(0), + }; + + status = reference_gemm0.can_implement(args0); + CUTLASS_CHECK(status); + status = reference_gemm0(args0); + CUTLASS_CHECK(status); + + using GemmUniversal1 = cutlass::gemm::device::GemmUniversal< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB1, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator + >; + + GemmUniversal1 reference_gemm1; + + typename GemmUniversal1::Arguments args1 { + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : + cutlass::gemm::GemmUniversalMode::kGemm), + problem_size, + batch_count, + {alpha1, beta1}, + tensor_A0.device_data(), + tensor_B1.device_data(), + tensor_Bias1.device_data(), + reference_D1.device_data(), + batch_stride_A, + batch_stride_B1, + batch_stride_Bias, + batch_stride_D, + tensor_A0.stride(0), + (broadcast_b1 ? 0 : tensor_B1.stride(0)), + 0, // zero stride for the bias vector + reference_D1.stride(0), + }; + + status = reference_gemm1.can_implement(args1); + CUTLASS_CHECK(status); + status = reference_gemm1(args1); + CUTLASS_CHECK(status); + if(relu) { cutlass::reference::device::TensorReLu(reference_D0.device_view()); - } - - reference_gemm_1( - problem_size, - alpha1, - tensor_A0.device_ref(), - tensor_B1.device_ref(), - beta1, - {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}, - reference_D1.device_ref() - ); - if(relu) { cutlass::reference::device::TensorReLu(reference_D1.device_view()); } + TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view()); cudaDeviceSynchronize(); reference_D0.sync_host(); @@ -793,7 +902,6 @@ struct DualFusedGemmRun bool passed = passed_out0 && passed_out1 && passed_out2; if (!passed) { - std::stringstream fname; fname << "error_DualGemm_device_fused.txt"; diff --git a/examples/45_dual_gemm/kernel/dual_gemm.h b/examples/45_dual_gemm/kernel/dual_gemm.h index 4cbddaa77..56ed9e7ea 100644 --- a/examples/45_dual_gemm/kernel/dual_gemm.h +++ b/examples/45_dual_gemm/kernel/dual_gemm.h @@ -42,6 +42,7 @@ #include "../threadblock/dual_mma_multistage.h" #include "../threadblock/dual_epilogue.h" +#include "../dual_gemm_common.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -92,6 +93,10 @@ struct DualGemm { true // IterationsUnroll >; + using ElementA = typename DualMma::IteratorA::Element; + using ElementB = typename DualMma::IteratorB0::Element; + using ElementC = typename DualEpilogue::OutputTileIterator::Element; + static bool const kSplitKSerial = SplitKSerial; static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1), "Split-K serial requires buffers for D0/D1 for reduction"); @@ -102,6 +107,7 @@ struct DualGemm { /// Parameters structure struct Params { + DualGemmMode mode; cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; int swizzle_log_tile; @@ 
-109,8 +115,8 @@ struct DualGemm { // Mma0 typename DualMma::IteratorA::Params params_A0; typename DualMma::IteratorA::TensorRef ref_A0; - typename DualMma::IteratorB::Params params_B0; - typename DualMma::IteratorB::TensorRef ref_B0; + typename DualMma::IteratorB0::Params params_B0; + typename DualMma::IteratorB0::TensorRef ref_B0; typename Epilogue0::OutputTileIterator::Params params_C0; typename Epilogue0::OutputTileIterator::TensorRef ref_C0; typename Epilogue0::OutputTileIterator::Params params_D0; @@ -118,8 +124,8 @@ struct DualGemm { typename OutputOp0::Params output_op_0; // Mma1 - typename DualMma::IteratorB::Params params_B1; - typename DualMma::IteratorB::TensorRef ref_B1; + typename DualMma::IteratorB1::Params params_B1; + typename DualMma::IteratorB1::TensorRef ref_B1; typename Epilogue1::OutputTileIterator::Params params_C1; typename Epilogue1::OutputTileIterator::TensorRef ref_C1; typename Epilogue1::OutputTileIterator::Params params_D1; @@ -133,6 +139,12 @@ struct DualGemm { int *semaphore; int gemm_k_size; + int64_t batch_stride_A; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C; + int64_t batch_stride_D; + // // Methods // @@ -142,15 +154,16 @@ struct DualGemm { CUTLASS_HOST_DEVICE Params( + DualGemmMode mode, cutlass::gemm::GemmCoord const & problem_size, cutlass::gemm::GemmCoord const & grid_tiled_shape, // Mma0: D0 = A @ B0 + C0 typename DualMma::IteratorA::TensorRef ref_A0, - typename DualMma::IteratorB::TensorRef ref_B0, + typename DualMma::IteratorB0::TensorRef ref_B0, typename Epilogue0::OutputTileIterator::TensorRef ref_C0, typename Epilogue0::OutputTileIterator::TensorRef ref_D0, // Mma1: D1 = A @ B1 + C1 - typename DualMma::IteratorB::TensorRef ref_B1, + typename DualMma::IteratorB1::TensorRef ref_B1, typename Epilogue1::OutputTileIterator::TensorRef ref_C1, typename Epilogue1::OutputTileIterator::TensorRef ref_D1, @@ -158,8 +171,14 @@ struct DualGemm { typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(), - int *workspace = nullptr + int *workspace = nullptr, + int64_t batch_stride_A = 1, + int64_t batch_stride_B0 = 1, + int64_t batch_stride_B1 = 1, + int64_t batch_stride_C = 1, + int64_t batch_stride_D = 1 ): + mode(mode), problem_size(problem_size), grid_tiled_shape(grid_tiled_shape), swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), @@ -183,13 +202,18 @@ struct DualGemm { ref_D2(ref_D2), output_op_0(output_op_0), output_op_1(output_op_1), - output_op_2(output_op_2) { + output_op_2(output_op_2), + batch_stride_A(batch_stride_A), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) { int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); gemm_k_size = gemm_k_iterations * DualMma::Shape::kK; - semaphore = workspace; + semaphore = workspace; } }; @@ -210,16 +234,16 @@ struct DualGemm { static Status can_implement( cutlass::gemm::GemmCoord const & problem_size, typename DualMma::IteratorA::TensorRef ref_A0, - typename DualMma::IteratorB::TensorRef ref_B0, + typename DualMma::IteratorB0::TensorRef ref_B0, typename Epilogue0::OutputTileIterator::TensorRef ref_C0, typename Epilogue0::OutputTileIterator::TensorRef ref_D0, - typename 
DualMma::IteratorB::TensorRef ref_B1,
+    typename DualMma::IteratorB1::TensorRef ref_B1,
     typename Epilogue1::OutputTileIterator::TensorRef ref_C1,
     typename Epilogue1::OutputTileIterator::TensorRef ref_D1,
     typename Epilogue1::OutputTileIterator::TensorRef ref_D2) {
 
     static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements;
-    static int const kAlignmentB = DualMma::IteratorB::AccessType::kElements;
+    static int const kAlignmentB = DualMma::IteratorB0::AccessType::kElements;
     static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess;
 
     if (!TensorRef_aligned(ref_A0, kAlignmentA)) {
@@ -273,52 +297,66 @@ struct DualGemm {
       return;
     }
 
+    int offset_k = 0;
+    int problem_size_k = params.problem_size.k();
+
+    ElementA *ptr_A0 = static_cast<ElementA *>(params.ref_A0.data());
+    ElementB *ptr_B0 = static_cast<ElementB *>(params.ref_B0.data());
+    ElementB *ptr_B1 = static_cast<ElementB *>(params.ref_B1.data());
+
+    //
+    // Fetch pointers based on mode.
+    //
+    if (params.mode == DualGemmMode::kGemm) {
+      if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
+        problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
+      }
+
+      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
+    }
+    else if (params.mode == DualGemmMode::kBatched) {
+      ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A;
+      ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
+      ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
+    }
+
     // Compute initial location in logical coordinates
     cutlass::MatrixCoord tb_offset_A0{
       threadblock_tile_offset.m() * DualMma::Shape::kM,
-      threadblock_tile_offset.k() * params.gemm_k_size,
+      offset_k,
     };
 
     cutlass::MatrixCoord tb_offset_B0{
-      threadblock_tile_offset.k() * params.gemm_k_size,
+      offset_k,
       threadblock_tile_offset.n() * DualMma::Shape::kN
     };
 
     cutlass::MatrixCoord tb_offset_B1{
-      threadblock_tile_offset.k() * params.gemm_k_size,
+      offset_k,
       threadblock_tile_offset.n() * DualMma::Shape::kN
     };
 
-    // Problem size is a function of threadblock index in the K dimension
-    int problem_size_k =
-      (params.problem_size.k() < (threadblock_tile_offset.k() + 1) * params.gemm_k_size) ?
-      params.problem_size.k() :
-      (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
-
-    // Compute threadblock-scoped matrix multiply-add
-    int gemm_k_iterations = (problem_size_k - tb_offset_A0.column() + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
-
     // Compute position within threadblock
     int thread_idx = threadIdx.x;
 
     // Construct iterators to A and B operands
     typename DualMma::IteratorA iterator_A0(
       params.params_A0,
-      params.ref_A0.data(),
+      ptr_A0,
       {params.problem_size.m(), problem_size_k},
       thread_idx,
       tb_offset_A0);
 
-    typename DualMma::IteratorB iterator_B0(
+    typename DualMma::IteratorB0 iterator_B0(
       params.params_B0,
-      params.ref_B0.data(),
+      ptr_B0,
       {problem_size_k, params.problem_size.n()},
       thread_idx,
       tb_offset_B0);
 
-    typename DualMma::IteratorB iterator_B1(
+    typename DualMma::IteratorB1 iterator_B1(
      params.params_B1,
-      params.ref_B1.data(),
+      ptr_B1,
      {problem_size_k, params.problem_size.n()},
      thread_idx,
      tb_offset_B1);
 
@@ -340,6 +378,9 @@ struct DualGemm {
     accum0.clear();
     accum1.clear();
 
+    // Compute threadblock-scoped matrix multiply-add
+    int gemm_k_iterations = (problem_size_k - offset_k + DualMma::Shape::kK - 1) / DualMma::Shape::kK;
+
     DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
     if (!kSplitKSerial || gemm_k_iterations > 0) {
       // Compute threadblock-scoped matrix multiply-add
@@ -372,31 +413,46 @@ struct DualGemm {
 
     int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
 
+    ElementC *ptr_C0 = static_cast<ElementC *>(params.ref_C0.data());
+    ElementC *ptr_C1 = static_cast<ElementC *>(params.ref_C1.data());
+    ElementC *ptr_D0 = static_cast<ElementC *>(params.ref_D0.data());
+    ElementC *ptr_D1 = static_cast<ElementC *>(params.ref_D1.data());
+    ElementC *ptr_D2 = static_cast<ElementC *>(params.ref_D2.data());
+
     // Construct the semaphore.
     Semaphore semaphore(params.semaphore + block_idx, thread_idx);
 
-    // If performing a reduction via split-K, fetch the initial synchronization
-    if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
-
-      // Fetch the synchronization lock initially but do not block.
-      semaphore.fetch();
+    if (params.mode == DualGemmMode::kGemm) {
+      // If performing a reduction via split-K, fetch the initial synchronization
+      if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
+
+        // Fetch the synchronization lock initially but do not block.
+        semaphore.fetch();
 
-      // Indicate which position in a serial reduction the output operator is currently updating
-      output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
-      output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+        // Indicate which position in a serial reduction the output operator is currently updating
+        output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+        output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
+      }
+    }
+    else if (params.mode == DualGemmMode::kBatched) {
+      ptr_C0 += threadblock_tile_offset.k() * params.batch_stride_C;
+      ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C;
+      ptr_D0 += threadblock_tile_offset.k() * params.batch_stride_D;
+      ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D;
+      ptr_D2 += threadblock_tile_offset.k() * params.batch_stride_D;
     }
 
     // Tile iterator loading from source tensor.
typename Epilogue0::OutputTileIterator iterator_C0( params.params_C0, - params.ref_C0.data(), + ptr_C0, params.problem_size.mn(), thread_idx, threadblock_offset ); typename Epilogue1::OutputTileIterator iterator_C1( params.params_C1, - params.ref_C1.data(), + ptr_C1, params.problem_size.mn(), thread_idx, threadblock_offset @@ -405,21 +461,21 @@ struct DualGemm { // Tile iterator writing to destination tensor. typename Epilogue0::OutputTileIterator iterator_D0( params.params_D0, - params.ref_D0.data(), + ptr_D0, params.problem_size.mn(), thread_idx, threadblock_offset ); typename Epilogue1::OutputTileIterator iterator_D1( params.params_D1, - params.ref_D1.data(), + ptr_D1, params.problem_size.mn(), thread_idx, threadblock_offset ); typename Epilogue1::OutputTileIterator iterator_D2( params.params_D2, - params.ref_D2.data(), + ptr_D2, params.problem_size.mn(), thread_idx, threadblock_offset diff --git a/examples/45_dual_gemm/threadblock/dual_mma_base.h b/examples/45_dual_gemm/threadblock/dual_mma_base.h index 10563e704..7031781ad 100644 --- a/examples/45_dual_gemm/threadblock/dual_mma_base.h +++ b/examples/45_dual_gemm/threadblock/dual_mma_base.h @@ -58,7 +58,9 @@ template < /// Size of the Gemm problem - concept: gemm::GemmShape<> typename Shape_, /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, + typename Policy0_, + /// B1-specific version of the policy (concept: MmaPolicy) + typename Policy1_, /// Number of stages, int Stages, /// Used for partial specialization @@ -69,18 +71,20 @@ class DualMmaBase { using Shape = Shape_; ///< Policy describing tuning details - using Policy = Policy_; + using Policy0 = Policy0_; + using Policy1 = Policy1_; // // Dependent types // /// Warp-level Mma - using Operator = typename Policy::Operator; + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; /// Shape describing the overall GEMM computed from shared memory /// by each warp. 
- using WarpGemm = typename Policy::Operator::Shape; + using WarpGemm = typename Policy0::Operator::Shape; /// Shape describing the number of warps filling the CTA using WarpCount = GemmShape; + using TensorRefA = TensorRef; /// Tensor reference to the B operand - using TensorRefB = TensorRef; + using TensorRefB0 = TensorRef; + using TensorRefB1 = TensorRef; static_assert(kWarpGemmIterations > 1, "The pipelined structure requires at least two warp-level " @@ -119,14 +124,17 @@ class DualMmaBase { // /// Shape of the A matrix operand in shared memory - using ShapeA = MatrixShape; + Policy0::SmemPaddingA::kColumn>; /// Shape of the B matrix operand in shared memory - using ShapeB = - MatrixShape; + using ShapeB0 = + MatrixShape; + using ShapeB1 = + MatrixShape; public: // @@ -134,11 +142,11 @@ class DualMmaBase { // /// Buffer for A operand - AlignedBuffer operand_A; + AlignedBuffer operand_A; /// Buffer for B operand - AlignedBuffer operand_B0; - AlignedBuffer operand_B1; + AlignedBuffer operand_B0; + AlignedBuffer operand_B1; public: @@ -148,14 +156,20 @@ class DualMmaBase { /// Returns a layout object for the A matrix CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + static typename Operator0::LayoutA LayoutA() { + return Operator0::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); } /// Returns a layout object for the B matrix CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + static typename Operator0::LayoutB LayoutB0() { + return Operator0::LayoutB::packed({ShapeB0::kRow, ShapeB0::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator1::LayoutB LayoutB1() { + return Operator1::LayoutB::packed({ShapeB1::kRow, ShapeB1::kColumn}); } /// Returns a TensorRef to the A operand @@ -166,12 +180,12 @@ class DualMmaBase { /// Returns a TensorRef to the B operand CUTLASS_HOST_DEVICE - TensorRefB operand_B0_ref() { - return TensorRefB{operand_B0.data(), LayoutB()}; + TensorRefB0 operand_B0_ref() { + return TensorRefB0{operand_B0.data(), LayoutB0()}; } CUTLASS_HOST_DEVICE - TensorRefB operand_B1_ref() { - return TensorRefB{operand_B1.data(), LayoutB()}; + TensorRefB1 operand_B1_ref() { + return TensorRefB1{operand_B1.data(), LayoutB1()}; } }; @@ -182,11 +196,11 @@ class DualMmaBase { // /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; + typename Operator0::IteratorA warp_tile_iterator_A_; /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B0_; - typename Operator::IteratorB warp_tile_iterator_B1_; + typename Operator0::IteratorB warp_tile_iterator_B0_; + typename Operator1::IteratorB warp_tile_iterator_B1_; public: diff --git a/examples/45_dual_gemm/threadblock/dual_mma_multistage.h b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h index 7843f2b2b..c486c69b9 100644 --- a/examples/45_dual_gemm/threadblock/dual_mma_multistage.h +++ b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h @@ -67,21 +67,30 @@ template < typename SmemIteratorA_, /// Cache operation for operand A cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory + /// Iterates over tiles of B0 operand in global memory // (concept: ReadableTileIterator | ForwardTileIterator | // MaskedTileIterator) - 
typename IteratorB_,
-    /// Iterates over tiles of B operand in shared memory
+    typename IteratorB0_,
+    /// Iterates over tiles of B0 operand in shared memory
     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
-    typename SmemIteratorB_,
+    typename SmemIteratorB0_,
     /// Cache operation for operand B
     cutlass::arch::CacheOperation::Kind CacheOpB,
+    /// Iterates over tiles of B1 operand in global memory
+    // (concept: ReadableTileIterator | ForwardTileIterator |
+    // MaskedTileIterator)
+    typename IteratorB1_,
+    /// Iterates over tiles of B1 operand in shared memory
+    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+    typename SmemIteratorB1_,
     /// Data type of accumulator matrix
     typename ElementC_,
     /// Data type of accumulator matrix
     typename LayoutC_,
     /// Policy describing tuning details (concept: MmaPolicy)
-    typename Policy_,
+    typename Policy0_,
+    /// B1-specific version of the policy (concept: MmaPolicy)
+    typename Policy1_,
     /// Number of stages,
     int Stages,
     /// Use zfill or predicate for out-of-bound cp.async
     SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
     /// Used for partial specialization
     typename Enable = bool>
 class DualMmaMultistage :
-  public DualMmaBase<Shape_, Policy_, Stages> {
+  public DualMmaBase<Shape_, Policy0_, Policy1_, Stages> {
 public:
   ///< Base class
-  using Base = DualMmaBase<Shape_, Policy_, Stages>;
+  using Base = DualMmaBase<Shape_, Policy0_, Policy1_, Stages>;
   ///< Size of the Gemm problem - concept: gemm::GemmShape<>
   using Shape = Shape_;
   ///< Iterates over tiles of A operand in global memory
   using IteratorA = IteratorA_;
-  ///< Iterates over tiles of B operand in global memory
-  using IteratorB = IteratorB_;
+  ///< Iterates over tiles of B0 operand in global memory
+  using IteratorB0 = IteratorB0_;
+  ///< Iterates over tiles of B1 operand in global memory
+  using IteratorB1 = IteratorB1_;
   ///< Data type of accumulator matrix
   using ElementC = ElementC_;
   ///< Layout of accumulator matrix
   using LayoutC = LayoutC_;
   ///< Policy describing tuning details
-  using Policy = Policy_;
+  using Policy0 = Policy0_;
+  using Policy1 = Policy1_;
 
   using SmemIteratorA = SmemIteratorA_;
-  using SmemIteratorB = SmemIteratorB_;
+  using SmemIteratorB0 = SmemIteratorB0_;
+  using SmemIteratorB1 = SmemIteratorB1_;
 
   static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
   static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
@@ -117,19 +130,21 @@ public:
   //
 
   /// Fragment of accumulator tile
-  using FragmentC = typename Policy::Operator::FragmentC;
+  using FragmentC = typename Policy0::Operator::FragmentC;
 
   /// Warp-level Mma
-  using Operator = typename Policy::Operator;
+  using Operator0 = typename Policy0::Operator;
+  using Operator1 = typename Policy1::Operator;
 
   /// Minimum architecture is Sm80 to support cp.async
   using ArchTag = arch::Sm80;
 
   /// Complex transform on A operand
-  static ComplexTransform const kTransformA = Operator::kTransformA;
+  static ComplexTransform const kTransformA = Operator0::kTransformA;
 
   /// Complex transform on B operand
-  static ComplexTransform const kTransformB = Operator::kTransformB;
+  static ComplexTransform const kTransformB0 = Operator0::kTransformB;
+  static ComplexTransform const kTransformB1 = Operator1::kTransformB;
 
   /// Internal structure exposed for introspection.
struct Detail { @@ -140,7 +155,7 @@ public: /// Number of cp.async instructions to load one stage of operand B static int const AsyncCopyIterationsPerStageB = - IteratorB::ThreadMap::Iterations::kCount; + IteratorB0::ThreadMap::Iterations::kCount; /// Number of stages static int const kStages = Stages; @@ -156,10 +171,12 @@ public: private: - using WarpLoadedFragmentA = typename Operator::FragmentA; - using WarpLoadedFragmentB = typename Operator::FragmentB; - using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; - using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using WarpLoadedFragmentA = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; private: @@ -171,8 +188,8 @@ public: SmemIteratorA smem_iterator_A_; /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B0_; - SmemIteratorB smem_iterator_B1_; + SmemIteratorB0 smem_iterator_B0_; + SmemIteratorB1 smem_iterator_B1_; public: @@ -215,7 +232,7 @@ public: } CUTLASS_DEVICE - void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B0, IteratorB &iterator_B1, + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB0 &iterator_B0, IteratorB1 &iterator_B1, int group_start_A = 0, int group_start_B = 0) { iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); @@ -253,9 +270,9 @@ public: } iterator_B0.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); + IteratorB0::kAccessesPerVector); iterator_B1.set_iteration_index(group_start_B * - IteratorB::kAccessesPerVector); + IteratorB1::kAccessesPerVector); this->smem_iterator_B0_.set_iteration_index(group_start_B); this->smem_iterator_B1_.set_iteration_index(group_start_B); @@ -263,16 +280,16 @@ public: CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( this->smem_iterator_B0_.get()); - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B0.get(); if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { @@ -292,16 +309,16 @@ public: CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( this->smem_iterator_B1_.get()); - int const kSrcBytes = sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < 
IteratorB::kAccessesPerVector; ++v) { + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { auto gmem_ptr = iterator_B1.get(); if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { @@ -330,8 +347,8 @@ public: ///< iterator over A operand in global memory IteratorA iterator_A, ///< iterator over B operand in global memory - IteratorB iterator_B0, - IteratorB iterator_B1, + IteratorB0 iterator_B0, + IteratorB1 iterator_B1, ///< initial value of accumulator FragmentC const &src_accum0, FragmentC const &src_accum1 @@ -386,16 +403,16 @@ public: // Async Copy for operand B0 CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( this->smem_iterator_B0_.get()); CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); @@ -408,16 +425,16 @@ public: // Async Copy for operand B1 CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( this->smem_iterator_B1_.get()); CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { int const kSrcBytes = - sizeof_bits::value * - IteratorB::ThreadMap::kElementsPerAccess / - IteratorB::kAccessesPerVector / 8; + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; cutlass::arch::cp_async_zfill( dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); @@ -473,35 +490,35 @@ public: ++last_smem_iterator_A; } - typename IteratorB::AccessType zero_B; + typename IteratorB0::AccessType zero_B; zero_B.clear(); /// Iterator to write threadblock-scoped tile of B0 operand to shared memory - SmemIteratorB last_smem_iterator_B0(this->smem_iterator_B0_); + SmemIteratorB0 last_smem_iterator_B0(this->smem_iterator_B0_); last_smem_iterator_B0.set_iteration_index(0); - // Async Copy for operand B + // Async Copy for operand B0 CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( last_smem_iterator_B0.get()); *dst_ptr = zero_B; ++last_smem_iterator_B0; } + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory - SmemIteratorB last_smem_iterator_B1(this->smem_iterator_B1_); + SmemIteratorB1 last_smem_iterator_B1(this->smem_iterator_B1_); last_smem_iterator_B1.set_iteration_index(0); - // Async Copy for operand B + // Async Copy for operand B1 CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { - typename IteratorB::AccessType *dst_ptr = - reinterpret_cast( + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( last_smem_iterator_B1.get()); *dst_ptr = zero_B; @@ -517,13 +534,14 @@ public: // Pair of fragments used to overlap shared memory loads and math // instructions WarpLoadedFragmentA 
warp_loaded_frag_A[2]; - WarpLoadedFragmentB warp_loaded_frag_B0[2]; - WarpLoadedFragmentB warp_loaded_frag_B1[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; WarpTransformedFragmentA warp_transformed_frag_A[2]; - WarpTransformedFragmentB warp_transformed_frag_B0[2]; - WarpTransformedFragmentB warp_transformed_frag_B1[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; - Operator warp_mma; + Operator0 warp_mma0; + Operator1 warp_mma1; this->warp_tile_iterator_A_.set_kgroup_index(0); this->warp_tile_iterator_B0_.set_kgroup_index(0); @@ -544,10 +562,10 @@ public: int smem_write_stage_idx = Base::kStages - 1; int smem_read_stage_idx = 0; - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], - warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); - warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], - warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); + warp_mma0.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); + warp_mma1.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); // tf32x3 kernels use staging accumulation. warp_mma uses a temporary // accumulator and this temporary accumulator is added to the final @@ -556,9 +574,9 @@ public: FragmentC tmp_accum0, tmp_accum1; - if (platform::is_same::value - || platform::is_same::value) { tmp_accum0.clear(); @@ -597,28 +615,28 @@ public: ++this->warp_tile_iterator_B1_; if (warp_mma_k > 0) { - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B0[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B0[warp_mma_k % 2]); - warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], - warp_transformed_frag_B1[warp_mma_k % 2], - warp_loaded_frag_A[warp_mma_k % 2], - warp_loaded_frag_B1[warp_mma_k % 2]); + warp_mma0.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + warp_mma1.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); } - if (platform::is_same::value - || platform::is_same::value) { - warp_mma( + warp_mma0( tmp_accum0, warp_transformed_frag_A[warp_mma_k % 2], warp_transformed_frag_B0[warp_mma_k % 2], tmp_accum0 ); - warp_mma( + warp_mma1( tmp_accum1, warp_transformed_frag_A[warp_mma_k % 2], warp_transformed_frag_B1[warp_mma_k % 2], @@ -632,13 +650,13 @@ public: tmp_accum1.clear(); } } else { - warp_mma( + warp_mma0( accum0, warp_transformed_frag_A[warp_mma_k % 2], warp_transformed_frag_B0[warp_mma_k % 2], accum0 ); - warp_mma( + warp_mma1( accum1, warp_transformed_frag_A[warp_mma_k % 2], warp_transformed_frag_B1[warp_mma_k % 2], @@ -696,14 +714,14 @@ public: if (smem_read_stage_idx == (Base::kStages - 1)) { this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations}); this->warp_tile_iterator_B0_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations, 0}); this->warp_tile_iterator_B1_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * + {-Base::kStages * Policy1::kPartitionsK * 
Base::kWarpGemmIterations, 0});
          smem_read_stage_idx = 0;
        }
@@ -720,22 +738,22 @@ public:
        // Do any conversions feeding the first stage at the end of the loop so
        // we can start right away on mma instructions
        if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
-          warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
-                             warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
-                             warp_loaded_frag_A[(warp_mma_k + 1) % 2],
-                             warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
-          warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
-                             warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
-                             warp_loaded_frag_A[(warp_mma_k + 1) % 2],
-                             warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
+          warp_mma0.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
+                              warp_transformed_frag_B0[(warp_mma_k + 1) % 2],
+                              warp_loaded_frag_A[(warp_mma_k + 1) % 2],
+                              warp_loaded_frag_B0[(warp_mma_k + 1) % 2]);
+          warp_mma1.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
+                              warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
+                              warp_loaded_frag_A[(warp_mma_k + 1) % 2],
+                              warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
        }
      }
    }
 
-    if (platform::is_same<typename Operator::MathOperator,
-                          arch::OpMultiplyAddFastF32>::value
-      || platform::is_same<typename Operator::MathOperator,
-                           arch::OpMultiplyAddComplexFastF32>::value) {
+    if (platform::is_same<typename Operator0::MathOperator,
+                          arch::OpMultiplyAddFastF32>::value
+      || platform::is_same<typename Operator0::MathOperator,
+                           arch::OpMultiplyAddComplexFastF32>::value) {
      accum0 = plus_accum(accum0, tmp_accum0);
      accum1 = plus_accum(accum1, tmp_accum1);
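For orientation, the batched dispatch reuses the split-K plumbing: the third
tiled dimension handed to the threadblock swizzle carries either the split-K
slice count or the batch count, which is why DualGemmMode::kBatched and
SplitKSerial are mutually exclusive. A condensed sketch of the two code paths,
assembled from the hunks in this patch (not new code):

    // Device side (initialize): grid.z is batch_count in batched mode.
    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices);

    // Kernel side: the same tile coordinate is reinterpreted per mode.
    if (params.mode == DualGemmMode::kGemm) {
      // threadblock_tile_offset.k() indexes a split-K slice of the K dimension.
      offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
    } else if (params.mode == DualGemmMode::kBatched) {
      // threadblock_tile_offset.k() indexes a batch; advance all pointers.
      ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A;
      ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
      ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;  // may differ from B0
    }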