From b079841b1009f791135ccd55998be21b66326d3a Mon Sep 17 00:00:00 2001
From: Erwin Terpstra
Date: Tue, 20 Jan 2026 22:06:59 +0100
Subject: [PATCH] Implement batched gemm add relu gemm add for rdna4 (#3391)

* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implementation
* wip: many fixes in implementation of batched gemm gemm multiple d
* wip: batched gemm gemm multiple d gridwise op compiling, not working yet
* fix: incorrect d0 grid indexing in batched gemm gemm multiple d
* feat: add instances for batched gemm add relu gemm add
* chore: configure instance with low vector transfer size for odd sizes
* chore: add some more validation to device batched gemm gemm multiple d, and remove a template parameter that didn't really make sense
* fix: update device_batched_gemm_gemm_wmma to work with new gridwise changes
* fix: disable odd size tests on XDL archs
* chore: remove temporary logging
* chore: update some references from C tensor to E tensor
* Tentative fix for example template params
* Tentative fix for non-multi-D batched gemm gemm device impl.
* Tentative fix for xdl example template params
* Tentative fix for profiler build on gfx90a
* chore: improve device batched gemm gemm multi D comment to include all ops and dimensions
* chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply
* fix: make the gemm1 data types match what happens in the device op
* feat: add d0s/d1s data types and layouts to the device op type string
* chore: change element-wise op so addition happens in fp32
* chore: add static asserts for gemm0/gemm1 calculated wave sizes
* chore: also update other element-wise ops to use fp32 calculations
* chore: log number of supported instances
* chore: update instance comment
* chore: disable kernel timing in example by default
* fix: gemm1 wave size calculation
* fix: make sure the batched gemm multiple d gemm multiple d profiler performs correct type conversions
* chore: remove increased tolerance in batched gemm gemm multiple d example
* chore: add comment explaining that verification fails for certain input values
* chore: clarify instance comment

---------

Co-authored-by: kiefer

[ROCm/composable_kernel commit: d5ae81b2922773f7cdf4a02a2e1fd57d0e4df851]
---
 .../CMakeLists.txt | 1 +
 ...d_gemm_add_add_relu_gemm_add_wmma_fp16.cpp | 135 +++
 ...ed_gemm_add_add_relu_gemm_add_xdl_fp16.cpp | 397 +-----
 ...atched_gemm_multiple_d_gemm_multiple_d.inc | 350 ++++++
 .../element_ops.h | 58 +
 ...ice_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 20 +-
 ...ple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp | 1072 +++++++++++++++++
 ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 10 +
 .../gpu/device/matrix_padder.hpp | 9 +
 ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 529 ++++++--
 .../transform_contraction_to_gemm.hpp | 24 +
 include/ck/utility/tuple_helper.hpp | 2 +-
 .../gpu/batched_gemm_add_relu_gemm_add.hpp | 63 +-
 .../CMakeLists.txt | 5 +-
 ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp | 72 ++
 ...6_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp | 72 ++
 ...d_gemm_multiple_d_gemm_multiple_d_impl.hpp | 387 ++++++
 test/CMakeLists.txt | 1 +
 .../CMakeLists.txt | 12 +
 .../test_batched_gemm_add_relu_gemm_add.cpp | 27 +
 ...atched_gemm_multiple_d_gemm_multiple_d.hpp | 121 ++
 ...mm_multiple_d_gemm_multiple_d_ut_cases.inc | 88 ++
 22 files changed, 2956 insertions(+), 499 deletions(-)
 create mode 100644 example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp
 create mode 100644 example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_multiple_d_gemm_multiple_d.inc
 create mode 100644 example/37_batched_gemm_add_add_relu_gemm_add/element_ops.h
 create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp
 create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
 create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
 create mode 100644 profiler/include/profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp
 create mode 100644 test/batched_gemm_multiple_d_gemm_multiple_d/CMakeLists.txt
 create mode 100644 test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_add_relu_gemm_add.cpp
 create mode 100644 test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_multiple_d_gemm_multiple_d.hpp
 create mode 100644 test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc

diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt
index 7b6eb01413..769b0888df 100644
--- a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt
+++ b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt
@@ -2,3 +2,4 @@
 # SPDX-License-Identifier: MIT
 
 add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
+add_example_executable(example_batched_gemm_add_add_relu_gemm_add_wmma_fp16 batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp)
diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp
new file mode 100644
index 0000000000..bf12ee6c3e
--- /dev/null
+++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp
@@ -0,0 +1,135 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +/* +Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o] +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" + +#include "element_ops.h" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using A0DataType = F16; +using B0DataType = F16; +using D00DataType = F16; +using D01DataType = F16; +using B1DataType = F16; +using D1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using E1DataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D00Layout = Row; +using D01Layout = Row; +using B1Layout = Row; +using D1Layout = Row; +using E1Layout = Row; + +using A0ElementOp = PassThrough; +using B0ElementOp = PassThrough; +using CDE0ElementOp = AddAddRelu; +using A1ElementOp = PassThrough; +using B1ElementOp = PassThrough; +using CDE1ElementOp = ck::tensor_operation::element_wise::Add; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< + A0Layout, + B0Layout, + ck::Tuple, + B1Layout, + ck::Tuple, + E1Layout, + A0DataType, + B0DataType, + ck::Tuple, + B1DataType, + ck::Tuple, + E1DataType, + AccDataType, + CShuffleDataType, + A0ElementOp, + B0ElementOp, + CDE0ElementOp, + B1ElementOp, + CDE1ElementOp, + GemmSpec, + + 32, // BlockSize + 16, // MPerBlock + 64, // LPerBlock + 64, // KPerBlock + 64, // NPerBlock (Gemm1NPerBlock) + 64, // LTilePerBlock (Gemm1KPerBlock) + 8, // AK1 + 8, // BK1 + 8, // L1 (B1K1) + 16, // MPerWmma + 16, // LPerWmma + 1, // MRepeat + 4, // LRepeat (Gemm0NRepeat) + 4, // NRepeat (Gemm1NRepeat) + + S<2, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + false, // ABlockLdsAddExtraM + S<2, 16, 1>, // B0BlockTransferThreadClusterLengths_K0_L_K1 + S<1, 0, 2>, // B0BlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // B0BlockTransferSrcAccessOrder + 2, // B0BlockTransferSrcVectorDim + 8, // B0BlockTransferSrcScalarPerVector + 8, // B0BlockTransferDstScalarPerVector_K1 + false, // B0BlockLdsAddExtraL + 4, // CDE0BlockTransferSrcScalarPerVector + S<2, 16, 1>, // B1BlockTransferThreadClusterLengths_L0_N_L1 + S<0, 2, 1>, // B1BlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // B1BlockTransferSrcAccessOrder + 1, // B1BlockTransferSrcVectorDim + 4, // B1BlockTransferSrcScalarPerVector + 2, // 
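// ---------------------------------------------------------------------------
// Editorial sketch (not part of the patch): the computation this example
// configures, written as a plain host-side reference. It assumes row-major
// A0/D00/D01/D1/E1, column-major B0, row-major B1, and fp32 accumulation as
// in AccDataType; all names below are illustrative, not CK APIs.
#include <algorithm>
#include <vector>

void reference_fused_gemm(int M, int N, int K, int O,
                          const std::vector<float>& a0,  // [M, K], row-major
                          const std::vector<float>& b0,  // [K, N], column-major: (k, n) at n * K + k
                          const std::vector<float>& d00, // [M, N]
                          const std::vector<float>& d01, // [M, N]
                          const std::vector<float>& b1,  // [N, O], row-major
                          const std::vector<float>& d1,  // [M, O]
                          std::vector<float>& e1)        // [M, O], pre-sized by caller
{
    std::vector<float> e0(std::size_t(M) * N);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f; // Gemm0: Acc0 = A0 * B0
            for(int k = 0; k < K; ++k)
                acc += a0[m * K + k] * b0[n * K + k];
            // CDE0 op (AddAddRelu), performed in fp32
            e0[m * N + n] = std::max(acc + d00[m * N + n] + d01[m * N + n], 0.f);
        }
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f; // Gemm1: Acc1 = E0 * B1
            for(int n = 0; n < N; ++n)
                acc += e0[m * N + n] * b1[n * O + o];
            // CDE1 op (Add)
            e1[m * O + o] = acc + d1[m * O + o];
        }
}
// ---------------------------------------------------------------------------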
B1BlockTransferDstScalarPerVector_L1 + true, // B1BlockLdsAddExtraN + 1, // CShuffleMRepeatPerShuffle + 2, // CShuffleNRepeatPerShuffle + S<1, 16, 1, 2>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +#include "batched_gemm_multiple_d_gemm_multiple_d.inc" +int main(int argc, char* argv[]) { return run_example(argc, argv); } diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index ab87124c6b..60a1316907 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -22,6 +22,8 @@ Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1 #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "element_ops.h" + using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; using ::ck::Tensor; @@ -39,11 +41,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using A0DataType = F16; using B0DataType = F16; -using Acc0DataType = F32; +using AccDataType = F32; using D00DataType = F16; using D01DataType = F16; using B1DataType = F16; -using Acc1DataType = F32; using C1ShuffleDataType = F32; using D1DataType = F16; using E1DataType = F16; @@ -56,58 +57,6 @@ using B1Layout = Row; using D1Layout = Row; using E1Layout = Row; -// E = Relu(C + D0 + D1) -struct AddAddRelu -{ - __host__ __device__ void - operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const - { - const ck::half_t x = c + d0 + d1; - - ck::tensor_operation::element_wise::Relu{}.operator()(e, x); - } - __host__ __device__ void - operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const - { - const float x = c + (d0 + d1); - - ck::tensor_operation::element_wise::Relu{}.operator()(e, x); - } -}; - -// E = Gelu(C + D0 + D1) -struct AddAddGelu -{ - __host__ __device__ void - operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const - { - const ck::half_t x = c + d0 + d1; - - ck::tensor_operation::element_wise::Gelu{}.template operator()(e, - x); - } - - __host__ __device__ void - operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const - { - const float x = c + (d0 + d1); - - ck::tensor_operation::element_wise::Gelu{}.template operator()(e, x); - } -}; - -// E = FastGelu(C + D0 + D1) -struct AddAddFastGelu -{ - __host__ __device__ void - operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const - { - const float x = c + (d0 + d1); - - ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); - } -}; - using A0ElementOp = PassThrough; using B0ElementOp = PassThrough; using CDE0ElementOp = AddAddRelu; @@ -131,10 +80,10 @@ using DeviceGemmInstance = E1Layout, A0DataType, B0DataType, - Acc0DataType, + AccDataType, ck::Tuple, B1DataType, - Acc1DataType, + AccDataType, C1ShuffleDataType, ck::Tuple, E1DataType, @@ -191,337 +140,5 @@ using DeviceGemmInstance = S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int 
init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 1024; - ck::index_t N = 1024; - ck::index_t K = 64; - ck::index_t O = 128; - ck::index_t BatchCount = 4; - ck::index_t StrideA0 = -1; - ck::index_t StrideB0 = -1; - ck::index_t StrideD00 = -1; - ck::index_t StrideD01 = -1; - ck::index_t StrideB1 = -1; - ck::index_t StrideD1 = -1; - ck::index_t StrideE1 = -1; - ck::index_t BatchStrideA0 = -1; - ck::index_t BatchStrideB0 = -1; - ck::index_t BatchStrideD00 = -1; - ck::index_t BatchStrideD01 = -1; - ck::index_t BatchStrideB1 = -1; - ck::index_t BatchStrideD1 = -1; - ck::index_t BatchStrideE1 = -1; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 9) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - O = std::stoi(argv[7]); - - BatchCount = std::stoi(argv[8]); - } - else if(argc == 23) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - O = std::stoi(argv[7]); - - BatchCount = std::stoi(argv[8]); - - StrideA0 = std::stoi(argv[9]); - StrideB0 = std::stoi(argv[10]); - StrideD00 = std::stoi(argv[11]); - StrideD01 = std::stoi(argv[12]); - StrideB1 = std::stoi(argv[13]); - StrideD1 = std::stoi(argv[14]); - StrideE1 = std::stoi(argv[15]); - - BatchStrideA0 = std::stoi(argv[16]); - BatchStrideB0 = std::stoi(argv[17]); - BatchStrideD00 = std::stoi(argv[18]); - BatchStrideD01 = std::stoi(argv[19]); - BatchStrideB1 = std::stoi(argv[20]); - BatchStrideD1 = std::stoi(argv[21]); - BatchStrideE1 = std::stoi(argv[22]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 8: M, N, K, O, Batch\n"); - printf( - "arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n"); - printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, " - "BatchStrideB1, BatchStrideD1, BatchStrideE1 \n"); - exit(0); - } - - const int DefaultStrideA0 = ck::is_same_v ? K : M; - const int DefaultStrideB0 = ck::is_same_v ? N : K; - const int DefaultStrideD00 = ck::is_same_v ? N : M; - const int DefaultStrideD01 = ck::is_same_v ? N : M; - const int DefaultStrideB1 = ck::is_same_v ? O : N; - const int DefaultStrideD1 = ck::is_same_v ? O : M; - const int DefaultStrideE1 = ck::is_same_v ? O : M; - - StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0; - StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; - StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00; - StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01; - StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; - StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1; - StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1; - - const int DefaultBatchStrideA0 = (ck::is_same_v ? K : M) * StrideA0; - const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; - const int DefaultBatchStrideD00 = (ck::is_same_v ? N : M) * StrideD00; - const int DefaultBatchStrideD01 = (ck::is_same_v ? N : M) * StrideD01; - const int DefaultBatchStrideB1 = (ck::is_same_v ? 
O : N) * StrideB1; - const int DefaultBatchStrideD1 = (ck::is_same_v ? O : M) * StrideD1; - const int DefaultBatchStrideE1 = (ck::is_same_v ? O : M) * StrideE1; - - BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0; - BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; - BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00; - BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01; - BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; - BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1; - BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1; - - auto f_host_tensor_descriptor = [](std::size_t batch_count, - std::size_t row, - std::size_t col, - std::size_t stride, - std::size_t batch_stride, - auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor( - {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); - } - else - { - return HostTensorDescriptor( - {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); - } - }; - - // E_m_o = A_m_k * B0_k_n * B1_n_o - Tensor a0_g_m_k( - f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{})); - Tensor b0_g_k_n( - f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); - Tensor d00_g_m_n( - f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{})); - Tensor d01_g_m_n( - f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{})); - Tensor b1_g_n_o( - f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); - Tensor d1_g_m_o( - f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{})); - Tensor e1_g_m_o_host_result( - f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{})); - Tensor e1_g_m_o_device_result( - f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{})); - - std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl; - std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; - std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc - << " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl; - std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc - << " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl; - std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; - std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a0_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 3}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); - d00_g_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); - d01_g_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); - d1_g_m_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); - break; - case 2: - a0_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - d00_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - d01_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d1_g_m_o.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - break; - default: - a0_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); - d00_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); - 
d01_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); - d1_g_m_o.GenerateTensorValue(GeneratorTensor_1{1}); - } - - DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize()); - DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); - DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize()); - DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize()); - DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); - DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) * - e1_g_m_o_device_result.mDesc.GetElementSize()); - DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize()); - - a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data()); - b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); - d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data()); - d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data()); - b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); - d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data()); - - auto a0_element_op = A0ElementOp{}; - auto b0_element_op = B0ElementOp{}; - auto cde0_element_op = CDE0ElementOp{}; - auto b1_element_op = B1ElementOp{}; - auto cde1_element_op = CDE1ElementOp{}; - - // do GEMM - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = - gemm.MakeArgument(static_cast(a0_g_m_k_device_buf.GetDeviceBuffer()), - static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), - std::array{d00_g_m_n_device_buf.GetDeviceBuffer(), - d01_g_m_n_device_buf.GetDeviceBuffer()}, - static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), - std::array{d1_g_m_o_device_buf.GetDeviceBuffer()}, - static_cast(e1_g_m_o_device_buf.GetDeviceBuffer()), - M, - N, - K, - O, - BatchCount, - StrideA0, - StrideB0, - std::array{StrideD00, StrideD01}, - StrideB1, - std::array{StrideD1}, - StrideE1, - BatchStrideA0, - BatchStrideB0, - std::array{BatchStrideD00, BatchStrideD01}, - BatchStrideB1, - std::array{BatchStrideD1}, - BatchStrideE1, - a0_element_op, - b0_element_op, - cde0_element_op, - b1_element_op, - cde1_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; - std::size_t num_btype = - (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N + - sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + - sizeof(D1DataType) * O) * - BatchCount; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data()); - - if(do_verification) - { - using ReferenceGemm0Instance = - ck::tensor_operation::host::ReferenceBatchedGemm; - - using ReferenceGemm1Instance = - ck::tensor_operation::host::ReferenceBatchedGemm; - - // Output of Gemm0 is input A of Gemm1 - Tensor c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); - Tensor e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); - Tensor 
c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{})); - - auto ref_gemm0 = ReferenceGemm0Instance{}; - auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); - auto ref_gemm0_argument = ref_gemm0.MakeArgument( - a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{}); - - ref_gemm0_invoker.Run(ref_gemm0_argument); - - // bias+bias+relu - e0_g_m_n.ForEach([&](auto&, auto idx) { - cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx)); - }); - - auto ref_gemm1 = ReferenceGemm1Instance{}; - auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); - auto ref_gemm1_argument = ref_gemm1.MakeArgument( - e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{}); - - ref_gemm1_invoker.Run(ref_gemm1_argument); - - // bias - e1_g_m_o_host_result.ForEach([&](auto&, auto idx) { - cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx)); - }); - - return ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result) ? 0 : 1; - } - - return 0; -} +#include "batched_gemm_multiple_d_gemm_multiple_d.inc" +int main(int argc, char* argv[]) { return run_example(argc, argv); } diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_multiple_d_gemm_multiple_d.inc b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_multiple_d_gemm_multiple_d.inc new file mode 100644 index 0000000000..57bdd5e9ef --- /dev/null +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_multiple_d_gemm_multiple_d.inc @@ -0,0 +1,350 @@ + +int run_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 256; + ck::index_t O = 512; + ck::index_t BatchCount = 4; + ck::index_t StrideA0 = -1; + ck::index_t StrideB0 = -1; + ck::index_t StrideD00 = -1; + ck::index_t StrideD01 = -1; + ck::index_t StrideB1 = -1; + ck::index_t StrideD1 = -1; + ck::index_t StrideE1 = -1; + ck::index_t BatchStrideA0 = -1; + ck::index_t BatchStrideB0 = -1; + ck::index_t BatchStrideD00 = -1; + ck::index_t BatchStrideD01 = -1; + ck::index_t BatchStrideB1 = -1; + ck::index_t BatchStrideD1 = -1; + ck::index_t BatchStrideE1 = -1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + + StrideA0 = std::stoi(argv[9]); + StrideB0 = std::stoi(argv[10]); + StrideD00 = std::stoi(argv[11]); + StrideD01 = std::stoi(argv[12]); + StrideB1 = std::stoi(argv[13]); + StrideD1 = std::stoi(argv[14]); + StrideE1 = std::stoi(argv[15]); + + BatchStrideA0 = std::stoi(argv[16]); + BatchStrideB0 = std::stoi(argv[17]); + BatchStrideD00 = std::stoi(argv[18]); + BatchStrideD01 = std::stoi(argv[19]); + BatchStrideB1 = std::stoi(argv[20]); + BatchStrideD1 = std::stoi(argv[21]); + BatchStrideE1 = std::stoi(argv[22]); + } + else + { + printf("arg1: 
verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 8: M, N, K, O, Batch\n"); + printf( + "arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n"); + printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, " + "BatchStrideB1, BatchStrideD1, BatchStrideE1 \n"); + exit(0); + } + + const int DefaultStrideA0 = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideD00 = ck::is_same_v ? N : M; + const int DefaultStrideD01 = ck::is_same_v ? N : M; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideD1 = ck::is_same_v ? O : M; + const int DefaultStrideE1 = ck::is_same_v ? O : M; + + StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00; + StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1; + StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1; + + const int DefaultBatchStrideA0 = (ck::is_same_v ? K : M) * StrideA0; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideD00 = (ck::is_same_v ? N : M) * StrideD00; + const int DefaultBatchStrideD01 = (ck::is_same_v ? N : M) * StrideD01; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideD1 = (ck::is_same_v ? O : M) * StrideD1; + const int DefaultBatchStrideE1 = (ck::is_same_v ? O : M) * StrideE1; + + BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00; + BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1; + BatchStrideE1 = BatchStrideE1 < 0 ? 
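// ---------------------------------------------------------------------------
// Editorial usage sketch: the argument layout parsed above. An illustrative
// invocation with verification on, integer initialization, timing off, and an
// explicit problem size (binary name taken from the CMakeLists.txt change):
//
//   ./example_batched_gemm_add_add_relu_gemm_add_wmma_fp16 1 1 0 256 256 64 128 2
//
// maps to: do_verification=1, init_method=1, time_kernel=0,
//          M=256, N=256, K=64, O=128, BatchCount=2;
// all strides then fall back to the defaults computed below.
// ---------------------------------------------------------------------------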
DefaultBatchStrideE1 : BatchStrideE1; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); + } + else + { + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); + } + }; + + // E_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a0_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor d00_g_m_n( + f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{})); + Tensor d01_g_m_n( + f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor d1_g_m_o( + f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{})); + Tensor e1_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{})); + Tensor e1_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{})); + + std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc + << " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl; + std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc + << " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + d00_g_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + d01_g_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + d1_g_m_o.GenerateTensorValue(GeneratorTensor_2{-2, 3}); + break; + case 2: + a0_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d00_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d01_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_g_m_o.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + default: + a0_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); + d00_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + d01_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + d1_g_m_o.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); + DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); + DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) * + 
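// ---------------------------------------------------------------------------
// Editorial sketch: the default-stride rule applied above, extracted into
// standalone helpers (hypothetical names, for illustration only). For a
// row-major [rows, cols] matrix the leading stride defaults to cols, for a
// column-major one to rows; the batch stride then covers one whole matrix.
int default_stride(int rows, int cols, bool row_major)
{
    return row_major ? cols : rows;
}

int default_batch_stride(int rows, int cols, int stride, bool row_major)
{
    return (row_major ? rows : cols) * stride;
}
// e.g. A0 is [M, K] row-major: StrideA0 = K and BatchStrideA0 = M * K.
// ---------------------------------------------------------------------------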
e1_g_m_o_device_result.mDesc.GetElementSize()); + DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize()); + + a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data()); + d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data()); + + auto a0_element_op = A0ElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto cde0_element_op = CDE0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto cde1_element_op = CDE1ElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(static_cast(a0_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + std::array{d00_g_m_n_device_buf.GetDeviceBuffer(), + d01_g_m_n_device_buf.GetDeviceBuffer()}, + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + std::array{d1_g_m_o_device_buf.GetDeviceBuffer()}, + static_cast(e1_g_m_o_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + BatchCount, + StrideA0, + StrideB0, + std::array{StrideD00, StrideD01}, + StrideB1, + std::array{StrideD1}, + StrideE1, + BatchStrideA0, + BatchStrideB0, + std::array{BatchStrideD00, BatchStrideD01}, + BatchStrideB1, + std::array{BatchStrideD1}, + BatchStrideE1, + a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = + (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N + + sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + + sizeof(D1DataType) * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data()); + + if(do_verification) + { + using ReferenceGemm0Instance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + using ReferenceGemm1Instance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + // Output of Gemm0 is input A of Gemm1 + Tensor c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + Tensor e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + Tensor c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{})); + + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // bias+bias+relu + // Note that we also convert from AccDataType to A0DataType to match what the device + // operation does + e0_g_m_n.ForEach([&](auto&, auto idx) { + AccDataType out; + cde0_element_op(out, c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx)); + e0_g_m_n(idx) = ck::type_convert(out); + }); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto 
ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{}); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // bias + e1_g_m_o_host_result.ForEach([&](auto&, auto idx) { + cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx)); + }); + + // NOTE: For float initialization (mode 2) verification currently fails due to inaccuracy. + // This seems to just be accumulating errors due to double gemm. It only seems to happen + // when using B1 tensor containing negative values, as this can get large values from gemm0 + // back to zero again but reduce the tolerance allowed by the relative tolerance. + // + // There doesn't seem to be any bug with the implementation, just a difference in order of + // operations between CPU and GPU causing an accumulating error. + bool validation_result = ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result); + std::cout << "Validation result: " << (validation_result ? "SUCCESS" : "FAIL") << "." + << std::endl; + + return validation_result ? 0 : 1; + } + + return 0; +} diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/element_ops.h b/example/37_batched_gemm_add_add_relu_gemm_add/element_ops.h new file mode 100644 index 0000000000..e9f22b862e --- /dev/null +++ b/example/37_batched_gemm_add_add_relu_gemm_add/element_ops.h @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" + +// E = Relu(C + D0 + D1) +struct AddAddRelu +{ + __host__ __device__ void + operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const ck::half_t x = c + d0 + d1; + + ck::tensor_operation::element_wise::Relu{}.operator()(e, x); + } + __host__ __device__ void + operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const float x = c + ck::type_convert(d0) + ck::type_convert(d1); + + ck::tensor_operation::element_wise::Relu{}.operator()(e, x); + } +}; + +// E = Gelu(C + D0 + D1) +struct AddAddGelu +{ + __host__ __device__ void + operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const ck::half_t x = c + d0 + d1; + + ck::tensor_operation::element_wise::Gelu{}.template operator()(e, + x); + } + + __host__ __device__ void + operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const float x = c + ck::type_convert(d0) + ck::type_convert(d1); + + ck::tensor_operation::element_wise::Gelu{}.template operator()(e, x); + } +}; + +// E = FastGelu(C + D0 + D1) +struct AddAddFastGelu +{ + __host__ __device__ void + operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const float x = c + ck::type_convert(d0) + ck::type_convert(d1); + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } +}; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp index 39fe913206..45ec3a2065 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -20,6 +20,7 @@ #include 
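// ---------------------------------------------------------------------------
// Editorial usage sketch (not part of the patch): AddAddRelu as the host
// verification applies it. The float overload does the additions in fp32
// before the Relu, matching the device op, which applies CDE0ElementOp to
// fp32 accumulator values. Values below are illustrative; the snippet assumes
// element_ops.h and the CK headers it relies on are already included.
inline float add_add_relu_demo()
{
    const float acc = 0.4f; // a Gemm0 accumulator value (AccDataType)
    const ck::half_t d0 = ck::type_convert<ck::half_t>(-0.25f);
    const ck::half_t d1 = ck::type_convert<ck::half_t>(-0.25f);
    float e0 = 0.f;
    AddAddRelu{}(e0, acc, d0, d1); // e0 = Relu(0.4 - 0.25 - 0.25) = 0
    return e0;
}
// Summing in fp32 and converting once at the end avoids the intermediate
// fp16 rounding that a half-precision chain of additions would introduce.
// ---------------------------------------------------------------------------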
"ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/utility/tuple.hpp" namespace ck { namespace tensor_operation { @@ -51,12 +52,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) GridwiseOp::template Run( arg.p_a_grid + a_batch_offset, arg.p_b0_grid + b0_batch_offset, + Tuple<>{}, // p_d0s_grid arg.p_b1_grid + b1_batch_offset, + Tuple<>{}, // p_d1s_grid arg.p_c_grid + c_batch_offset, p_shared, arg.a_grid_desc, arg.b0_grid_desc, + Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock arg.b1_grid_desc, + Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock arg.c_grid_desc_mblock_mperblock_nblock_nperblock, arg.a_element_op, arg.b0_element_op, @@ -240,8 +245,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, // Ds0DataType AccDataType, // Acc0DataType B1DataType, + Tuple<>, // Ds1DataType AccDataType, // Acc1DataType CShuffleDataType, CDataType, @@ -255,7 +262,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, // Ds0GridDesc B1GridDesc, + Tuple<>, // Ds1GridDesc CGridDesc_M_N, // Tiling Family MPerBlock, @@ -290,6 +299,7 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm{}, arg.b1_grid_desc, + Tuple<>{}, arg.c_grid_desc_m_n, arg.block_2_ctile_map)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..06651c0c0e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,1072 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t e1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx))); + + auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer(); + auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer(); + + static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset; + }); + + static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) { + const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); + p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset; + }); + + GridwiseOp::template Run( + arg.p_a_grid + a_batch_offset, + arg.p_b0_grid + b0_batch_offset, + p_d0s_grid, + arg.p_b1_grid + b1_batch_offset, + p_d1s_grid, + arg.p_e1_grid + e1_batch_offset, + p_shared, + arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e1_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op, + arg.b0_element_op, + arg.acc_element_op, + arg.b1_element_op, + arg.cde1_element_op, + arg.block_2_etile_map); +#else + ignore = arg; +#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) +} + +// Computes: +// Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...) +// E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...) 
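// ---------------------------------------------------------------------------
// Editorial sketch: how the kernel above maps flat block indices to batches.
// The grid is sized batch_count * M0 * N0 blocks, so dividing the block id by
// blocks-per-batch recovers the batch index g_idx; every grid pointer is then
// advanced by g_idx times its batch stride. Host-side model for illustration:
struct BlockToBatch
{
    int grid_size;   // batch_count * M0 * N0
    int batch_count;

    int batch_of(int block_id) const
    {
        const int num_blocks_per_batch = grid_size / batch_count;
        return block_id / num_blocks_per_batch;
    }
};
// e.g. grid_size = 4 * 64 * 8 and batch_count = 4: blocks [0, 512) work on
// batch 0, blocks [512, 1024) on batch 1, and so on.
// ---------------------------------------------------------------------------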
+// +// Dimensions: +// A = MK +// B0 = KL +// Acc/D0s = ML +// B1 = LN +// E/D1s = MN +template +struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 + : public DeviceBatchedGemmMultipleDGemmMultipleD +{ + using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3; + + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + + // To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal + // to LPerWmma (A.k.a Gemm0 NPerWmma). + static constexpr index_t NPerWmma = LPerWmma; + + // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler + // Transform operator or just not use one at all. + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence<1, 1, 1, 1, 1>, + Sequence, + GemmSpec, + TensorSpecialization::Default, // ASpec + TensorSpecialization::Default, // B0Spec + TensorSpecialization::Default, // B1Spec + TensorSpecialization::Default>; // CSpec + + __host__ __device__ static auto + MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, + const std::array& a_g_m_k_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, + const std::array& b0_g_l_k_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, + const std::array& b1_g_n_l_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, + const std::array& d0_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); + } + + __host__ __device__ static auto MakeD0sGridDescriptor( + const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, + const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto MakeD1sGridDescriptor( + const std::array, NumD0Tensor>& d1_g_m_o_lengths_vec, + const std::array, NumD0Tensor>& d1_g_m_o_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto + MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, + const std::array& e1_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using D0sGridDesc = remove_cvref_t; + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using D1sGridDesc = remove_cvref_t; + using E1GridDesc = decltype(MakeE1GridDescriptor({}, {})); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + 
index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1) + : BatchStrideA0_(BatchStrideA0), + BatchStrideB0_(BatchStrideB0), + BatchStrideD0s_(BatchStrideD0s), + BatchStrideB1_(BatchStrideB1), + BatchStrideD1s_(BatchStrideD1s), + BatchStrideE1_(BatchStrideE1) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA0_); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB0_); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD0s_[d1_idx]); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE1_); + } + + template + __host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD1s_[d1_idx]); + } + + private: + index_t BatchStrideA0_; + index_t BatchStrideB0_; + std::array BatchStrideD0s_; + index_t BatchStrideB1_; + std::array BatchStrideD1s_; + index_t BatchStrideE1_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< + // DataType Family + ADataType, + B0DataType, + D0sDataType, + AccDataType, // Acc0DataType + B1DataType, + D1sDataType, + AccDataType, // Acc1DataType + CShuffleDataType, + E1DataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CDE1ElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + D0sGridDesc, + B1GridDesc, + D1sGridDesc, + E1GridDesc, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0BlockLdsAddExtraL, + CDE0BlockTransferSrcScalarPerVector, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + BlkGemmPipeSched, + BlkGemmPipelineVer>; + + struct RawArg : public BaseArgument + { + using arr3 = std::array; + + RawArg(const ADataType* p_a_grid_, + const B0DataType* p_b0_grid_, + std::array p_d0s_grid_, + const B1DataType* p_b1_grid_, + std::array p_d1s_grid_, + E1DataType* p_e1_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t O_, + index_t Batch, + index_t 
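// ---------------------------------------------------------------------------
// Editorial note: the base-pointer getters above cast to long_index_t before
// multiplying because g_idx * BatchStride can overflow a 32-bit index_t for
// large problems. Minimal illustration (hypothetical helper):
long long batch_offset(int g_idx, int batch_stride)
{
    // widen before the multiply, as GetABasePtr() and friends do
    return static_cast<long long>(g_idx) * batch_stride;
}
// e.g. g_idx = 40 with batch_stride = 64 * 1024 * 1024 fits in 64 bits but
// would overflow a 32-bit signed index.
// ---------------------------------------------------------------------------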
StrideA, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + AElementwiseOperation a_element_op_, + B0ElementwiseOperation b0_element_op_, + AccElementwiseOperation acc_element_op_, + B1ElementwiseOperation b1_element_op_, + CDE1ElementwiseOperation cde1_element_op_) + : p_a_grid{p_a_grid_}, + p_b0_grid{p_b0_grid_}, + p_d0s_grid{}, + p_b1_grid{p_b1_grid_}, + p_d1s_grid{}, + p_e1_grid{p_e1_grid_}, + M{M_}, + N{N_}, + K{K_}, + O{O_}, + batch_count{Batch}, + a_element_op{a_element_op_}, + b0_element_op{b0_element_op_}, + acc_element_op{acc_element_op_}, + b1_element_op{b1_element_op_}, + cde1_element_op{cde1_element_op_}, + compute_base_ptr_of_batch{BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1} + { + + a_g_m_k_lengths = arr3{batch_count, M, K}; + a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] + + b0_g_n_k_lengths = arr3{batch_count, N, K}; + b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] + + b1_g_o_n_lengths = arr3{batch_count, O, N}; + b1_g_o_n_strides = + is_same_v + ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] + : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] + + e1_g_m_o_lengths = arr3{batch_count, M, O}; + e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] + + a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); + b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); + b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); + e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides); + e1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e1_grid_desc_m_n); + block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1); + + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0DataType = remove_cvref_t>; + + // D0s layout [batch_count, M, N] + d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; + d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; + + // D0 pointer + p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); + + // D0 desc + d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]); + }); + + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + using D1DataType = remove_cvref_t>; + + // D1s layout [batch_count, M, O] + d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; + d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; + + // D1 pointer + p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); + + // D1 desc + d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]); + }); + + d1s_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc); + } + + // Pointers + const ADataType* p_a_grid; + const B0DataType* p_b0_grid; + typename GridwiseOp::D0sGridPointer p_d0s_grid; + const B1DataType* p_b1_grid; + typename GridwiseOp::D1sGridPointer p_d1s_grid; + E1DataType* p_e1_grid; + + // Raw Problem Size + index_t M; + index_t N; + index_t K; + index_t O; + index_t batch_count; + + arr3 a_g_m_k_lengths; + arr3 a_g_m_k_strides; + arr3 b0_g_n_k_lengths; + arr3 b0_g_n_k_strides; + std::array 
d0s_g_m_n_lengths; + std::array d0s_g_m_n_strides; + arr3 b1_g_o_n_lengths; + arr3 b1_g_o_n_strides; + std::array d1s_g_m_o_lengths; + std::array d1s_g_m_o_strides; + arr3 e1_g_m_o_lengths; + arr3 e1_g_m_o_strides; + + AElementwiseOperation a_element_op; + B0ElementwiseOperation b0_element_op; + AccElementwiseOperation acc_element_op; + B1ElementwiseOperation b1_element_op; + CDE1ElementwiseOperation cde1_element_op; + + // Grid descriptors and other mem calculators + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + D0sGridDesc d0s_grid_desc; + B1GridDesc b1_grid_desc; + D1sGridDesc d1s_grid_desc; + typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + d1s_grid_desc_mblock_mperblock_nblock_nperblock; + + E1GridDesc e1_grid_desc_m_n; + typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e1_grid_desc_mblock_mperblock_nblock_nperblock; + + typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map; + + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; + }; + + // check if DsLayout is supported + template + static constexpr bool CheckDLayout() + { + bool valid = true; + // iterate over DLayout tuple + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // if RefLayout and DLayout are same, keep valid true, otherwise false + valid = valid && is_same_v; + }); + return valid; + } + + static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) + { + // Print lambda with env check and printf() style formmating. + const char* curFunc = __func__; + auto print = [&curFunc](const char* format, ...) -> void { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#endif + va_list args; + va_start(args, format); + std::vfprintf(stdout, format, args); + va_end(args); +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; + } + }; + + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + print("DeviceOp: Arch err\n"); + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + print("DeviceOp: gfx 11 does not support fp8\n"); + return false; + } + } + + if constexpr(!(is_same_v || is_same_v)) + { + print("DeviceOp: Acc0 Type err\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: A layout must be Row\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: B0 layout must be Column\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D0s layout must be Row\n"); + return false; + } + + if constexpr(!(is_same_v || + is_same_v)) + { + print("DeviceOp: B1 layout must be Column or Row\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D1s layout must be Row\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: C layout must be Row\n"); + return false; + } + + // Other padding modes have not been tested and do not get checked individually. + if constexpr(GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MNKOPadding) + { + print("Padding mode must be default or MNKO\n"); + return false; + } + + // Per wmma dimensions not equal to 16 are very untested. 
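// ---------------------------------------------------------------------------
// Editorial sketch: CheckDLayout above folds an is_same test over the Ds
// layout tuple via static_for. A plain standard-C++ equivalent using a fold
// expression (for illustration; not the CK implementation):
#include <type_traits>

template <typename RefLayout, typename... DLayouts>
inline constexpr bool all_layouts_match = (std::is_same_v<RefLayout, DLayouts> && ...);

// e.g. all_layouts_match<Row, Row, Row> is true, all_layouts_match<Row, Col>
// is false; an empty pack folds to true, just like iterating an empty tuple.
// ---------------------------------------------------------------------------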
+ if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) + { + print("M, L, N per Wmma must be 16\n"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc, + arg.e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto cde1_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; + const auto e1_stride_lowest = arg.e1_g_m_o_strides[2]; + + // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major + // and the lowest dimension stride is hardcoded to 1 + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + e1_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); + + const index_t grid_size = arg.batch_count * M0 * N0; + + auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { + constexpr bool has_loop = decltype(has_main_k_block_loop)::value; + constexpr TailNumber tn = tail_number; + + const auto kernel = + kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3; + + return launch_and_time_kernel( + stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); + }; + + bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); + TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for 
V1!\n"); + return 0.0f; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); + return 0.0f; + } + } + else + { + printf("Invalid pipeline version!\n"); + return 0.0f; + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeArgument(const ADataType* p_a0, + const B0DataType* p_b0, + std::array p_d0s, + const B1DataType* p_b1, + std::array p_d1s, + E1DataType* p_e1, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA0, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + AElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) + { + return RawArg{p_a0, p_b0, + p_d0s, p_b1, + p_d1s, p_e1, + MRaw, NRaw, + KRaw, Gemm1NRaw, + Batch, StrideA0, + StrideB0, StrideD0s, + StrideB1, StrideD1s, + StrideE1, BatchStrideA0, + BatchStrideB0, BatchStrideD0s, + BatchStrideB1, BatchStrideD1s, + BatchStrideE1, a0_element_op, + b0_element_op, cde0_element_op, + b1_element_op, cde1_element_op}; + } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b0, + std::array p_d0s, + const void* p_b1, + std::array p_d1s, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA, + ck::index_t StrideB0, + std::array StrideD0s, + ck::index_t StrideB1, + std::array StrideD1s, + ck::index_t StrideE1, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB0, + std::array BatchStrideD0s, + ck::index_t BatchStrideB1, + std::array BatchStrideD1s, + ck::index_t BatchStrideE1, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + p_d0s, + static_cast(p_b1), + p_d1s, + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideE1, + BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + template + static constexpr const char* DataTypeToString() + { + if constexpr(std::is_same_v) + { + return "fp32"; + } + else if constexpr(std::is_same_v) + { + return 
"fp16"; + } + else if constexpr(std::is_same_v) + { + return "bf16"; + } + else if constexpr(std::is_same_v) + { + return "fp8"; + } + else if constexpr(std::is_same_v) + { + return "bf8"; + } + else if constexpr(std::is_same_v) + { + return "int32"; + } + else if constexpr(std::is_same_v) + { + return "int8"; + } + else if constexpr(std::is_same_v) + { + return "int4"; + } + else + { + return "unknown"; + } + } + + template + std::string DataTypeTupleToString() const + { + const auto string_types = generate_tuple( + [&](auto i) { + using ElementType = remove_cvref_t>; + return DataTypeToString(); + }, + Number{}); + + return TupleReduce<0, DataTypes::Size()>( + [&](std::string s, std::string a) { return a + ", " + s; }, string_types); + }; + + template + std::string LayoutTupleToString() const + { + const auto string_layouts = generate_tuple( + [&](auto i) { + using ElementLayout = remove_cvref_t>; + return std::string(1, ElementLayout::name[0]); + }, + Number{}); + + return TupleReduce<0, Layouts::Size()>([&](std::string s, std::string a) { return a + s; }, + string_layouts); + }; + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3" + << "() << " " + << "D1s: " << LayoutTupleToString() + << ", " + << "A " << DataTypeToString() << ", " + << "B0 " << DataTypeToString() << ", " + << "D0s (" << DataTypeTupleToString() << "), " + << "B1 " << DataTypeToString() << ", " + << "D1s (" << DataTypeTupleToString() << "), " + << "E1 " << DataTypeToString() << ", " + << "Acc " << DataTypeToString() << ", " + << "Cshuf " << DataTypeToString() << ", " + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << "BlkGemmPipelinePrefetchStages: " + << GridwiseOp::BlockwiseGemmPipe::PrefetchStages + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 4410871ac1..0823ca5f17 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -825,6 +825,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle { if(!ck::is_xdl_wmma_supported()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "wrong! XDL/WMMA not supported for these datatypes or operation sizes." 
+ << std::endl; + } return false; } @@ -843,6 +848,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle CheckDLayout() && is_same_v)) { + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "wrong! Unsupported tensor layout combination." << std::endl; + } return false; } diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index 95e7bd367a..6ead8a955c 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -101,6 +101,15 @@ struct GemmGemmPadder b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); } + // D0[M, N] + template + __host__ __device__ constexpr auto + PadD0Descriptor_M_N(const D0Desc_MRaw_NRaw& d0_desc_mraw_nraw) const + { + return PadTensorDescriptor( + d0_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence{}); + } + // B1[Gemm1N, Gemm1K] = B1[O, N] template __host__ __device__ constexpr auto diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index 121ca258be..e30ddf5c1c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -13,31 +13,35 @@ #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" namespace ck { -// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L] -// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N] +// Gemm0: AccOp(A [M x K] x B0 [K x L], D0) = Acc [M x L] +// Gemm1: CDEOp1(Acc [M x L] x B1 [L x N], D1) = E [M x N] template @@ -94,6 +99,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 static constexpr auto I5 = Number<5>{}; static constexpr auto I6 = Number<6>{}; + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + static constexpr auto AK0 = Number{}; static constexpr auto AK1 = Number{}; static constexpr auto BK0 = Number{}; @@ -105,9 +113,19 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 static constexpr auto BL0 = Number{}; static constexpr auto BL1 = Number{}; - static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); - static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); - static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WaveSize0 = BlockSize / (MWaves * LWaves); + static constexpr auto WaveSize1 = BlockSize / (MWaves * NWaves); + static constexpr auto WaveSize = WaveSize0; + + static_assert( + WaveSize0 == 32 || WaveSize0 == 64, + "Misconfigured wave parameters: BlockSize / (MWaves * LWaves) != 32/64 threads per wave"); + static_assert( + WaveSize1 == 32 || WaveSize1 == 64, +
"Misconfigured wave parameters: BlockSize / (MWaves * NWaves) != 32/64 threads per wave"); static constexpr index_t KPerWmmaBlk = WmmaSelector::selected_wmma @@ -212,6 +230,52 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 return b1_block_copy_step; } + // ck::Tuple + static constexpr auto MakeD0sGridPointer() + { + return generate_tuple( + [&](auto i) { + using D0iDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // ck::Tuple + static constexpr auto MakeD1sGridPointer() + { + return generate_tuple( + [&](auto i) { + using D1iDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __device__ static auto GetGemm0WaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, LWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id) + { + constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(WaveSize / LPerWmma, LPerWmma))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + template __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) { @@ -369,14 +433,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() { - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, I1, Number{})); - return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + return c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() @@ -432,12 +496,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 true>())>; // TransposeC (must be true to work), C' = B' x A' // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01} - template + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, const B0GridDesc& b0_grid_desc, + const D0sGridDesc& d0s_grid_desc, const B1GridDesc& b1_grid_desc, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + const D1sGridDesc& d1s_grid_desc, + const E1GridDesc& c_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) { // Print lambda with env check and printf() style formmating. 
const char* curFunc = __func__; @@ -482,6 +548,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 return false; } + bool d0s_desc_valid = true; + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + if(!(M == d0s_grid_desc[i].GetLength(I0) && L == d0s_grid_desc[i].GetLength(I1))) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: M/L Length err, A_M/B0_L = %d, %d | D0s_M/N = %d, %d\n", + M, + L, + d0s_grid_desc[i].GetLength(I0), + d0s_grid_desc[i].GetLength(I1)); + } + + d0s_desc_valid = false; + } + }); + + bool d1s_desc_valid = true; + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + if(!(M == d1s_grid_desc[i].GetLength(I0) && N == d1s_grid_desc[i].GetLength(I1))) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: M/N Length err, A_M/N = %d, %d | D1s_M/N = %d, %d\n", + M, + N, + d1s_grid_desc[i].GetLength(I0), + d1s_grid_desc[i].GetLength(I1)); + } + d1s_desc_valid = false; + } + }); + + if(!(d0s_desc_valid && d1s_desc_valid)) + { + return false; + } + if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) @@ -513,11 +617,11 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 return false; } - if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + if(!block_2_etile_map.CheckValidity(c_grid_desc_m_n)) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - print("GridwiseOp: invalid block_2_ctile_map\n"); + print("GridwiseOp: invalid block_2_etile_map\n"); } return false; } @@ -539,37 +643,94 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 } __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc& e_grid_desc_m_n) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); const auto MBlock = M / MPerBlock; const auto NBlock = N / NPerBlock; - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - c_grid_desc_m_n, + const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), make_unmerge_transform(make_tuple(NBlock, Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - return c_grid_desc_mblock_mperblock_nblock_nperblock; + return e1_grid_desc_mblock_mperblock_nblock_nperblock; } - // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( - const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + // D0 desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs( + const D0GridDesc_M_N& d0_grid_desc_m_n) { - return BlockToCTileMap_M00_N0_M01Adapt( - c_grid_desc_m_n); + const auto M = d0_grid_desc_m_n.GetLength(I0); + const auto N = d0_grid_desc_m_n.GetLength(I1); + + constexpr auto wmma = + WmmaSelector::selected_wmma; + + return transform_tensor_descriptor( + d0_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / MPerBlock, MRepeat, MWaves, MPerWmma)), + make_unmerge_transform(make_tuple(N / LPerBlock, + LRepeat, + LWaves, + WaveSize / LPerWmma, + wmma.num_acc_vgprs_per_wave))), 
+ make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 3, 4>{}, Sequence<1, 5, 6, 7, 8>{})); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - using DefaultBlock2CTileMap = - remove_cvref_t; + // D0s desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs( + const D0sGridDesc& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs( + ds_grid_desc_m_n[i]); + }, + Number{}); + } + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDescriptor_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2ETileMap( + const E1GridDesc& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt(c_grid_desc_m_n); + } + + using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using D0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs = + remove_cvref_t< + decltype(MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs( + D0sGridDesc{}))>; + + using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2ETileMap = + remove_cvref_t; struct SharedMemTrait { @@ -600,45 +761,69 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 .GetElementSpaceSize(); }; + using D0sGridPointer = decltype(MakeD0sGridPointer()); + using D1sGridPointer = decltype(MakeD1sGridPointer()); + template + typename Block2ETileMap = DefaultBlock2ETileMap> __device__ static void Run(const ADataType* __restrict__ p_a_grid, const B0DataType* __restrict__ p_b0_grid, + D0sGridPointer p_d0s_grid, const B1DataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid, + D1sGridPointer p_d1s_grid, + E1DataType* __restrict__ p_e1_grid, void* __restrict__ p_shared, const AGridDesc& a_grid_desc, const B0GridDesc& b0_grid_desc, + const D0sGridDesc& d0s_grid_desc, const B1GridDesc& b1_grid_desc, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, + const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e1_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation& a_element_op, const B0ElementwiseOperation& b0_element_op, const AccElementwiseOperation& acc_element_op, const B1ElementwiseOperation& b1_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) + const CDEElementwiseOperation& c_element_op, + const Block2ETileMap& block_2_etile_map) { // clang-format off /*******************************************************************************/ // Memory buffer zone. 
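+ // D0 fusion note: the D0 tensors bypass LDS entirely. The grid view built
+ // below exposes D0 through the same 9 dimensions as acc0_thread_buf
+ // (MBlockId, NBlockId, mrepeat, mwave, mthreadpersubgroup, lrepeat, lwave,
+ // lsubgroup, laccvgprs), so the Acc/D0 fusion later in this function is a
+ // plain per-register loop; conceptually (sketch, assuming a single D0 tensor
+ // and AddRelu as the acc element-wise op):
+ //   static_for<0, acc0_thread_buf.Size(), 1>{}([&](auto i) {
+ //       acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i], d0s_thread_buf[I0][i]);
+ //       // i.e. acc = relu(acc + d0), with the addition carried out in fp32
+ //   });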
+ const auto d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(d0s_grid_desc); const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc.GetElementSpaceSize()); const auto b0_grid_buf = make_dynamic_buffer( p_b0_grid, b0_grid_desc.GetElementSpaceSize()); const auto b1_grid_buf = make_dynamic_buffer( p_b1_grid, b1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto e1_grid_buf = make_dynamic_buffer( + p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + const auto d0s_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_d0s_grid[i], + d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i].GetElementSpaceSize()); + }, + Number{}); + const auto d1s_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_d1s_grid[i], + d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); /*******************************************************************************/ // BlockIdx.x -> [BlockId.m, BlockId.n] - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - if(!block_2_ctile_map.ValidCTileIndex( + const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_etile_map.ValidCTileIndex( block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) { return; } // Store BlockId into SGPR @@ -757,6 +942,72 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + // d0 matrix threadwise copy + constexpr auto d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + make_naive_tensor_descriptor_packed(make_tuple( + I1, // MBlockId + I1, // NBlockID + mrepeat, + mwave, + mthreadpersubgroup, + lrepeat, + lwave, + lsubgroup, + laccvgprs)); + + auto d0s_thread_buf = generate_tuple( + [&](auto i) { + using D0iDataType = remove_cvref_t>; + return StaticBuffer< + AddressSpaceEnum::Vgpr, + D0iDataType, + d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize(), + true>{}; + }, + Number{}); + + const auto wave_id = GetGemm0WaveIdx(); // I0: MWaves, I1: LWaves, I2: WaveSize + const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I0: WaveSize / LPerWmma, I1: LPerWmma + + static_assert(CDE0BlockTransferSrcScalarPerVector <= laccvgprs, + "vector load must not be greater than n4"); + static_assert(laccvgprs % CDE0BlockTransferSrcScalarPerVector == 0); + + auto d0s_threadwise_copy = generate_tuple( + [&](auto i) { + using D0iDataType = remove_cvref_t>; + return ThreadwiseTensorSliceTransfer_v2< + D0iDataType, + D0iDataType, + decltype(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i]), +
decltype(d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs), + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>, + 8, // NOTE: XDL has this exposed as CDE0BlockTransferSrcVectorDim. + // But as the grid descriptor is built internally, the parameter doesn't really make sense to configure per instance + CDE0BlockTransferSrcScalarPerVector, + 1, + false>(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i], + make_multi_index(block_work_idx[I0], // MBlockId + 0, // NBlockId + 0, // mrepeat + wave_id[I0], // mwave + wave_m_n_id[I1], // mthreadpersubgroup + 0, // nrepeat + wave_id[I1], // nwave + wave_m_n_id[I0], // nsubgroup + 0)); // register number + }, + Number{}); + constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor( acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)), @@ -924,9 +1175,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 b_scale_struct, KBlockMainLoop, 1); // num_k_block_per_scale + // multiple d + if constexpr(NumD0Tensor) + { + constexpr auto d0s_thread_buf_size = d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize(); - static_for<0, acc0_thread_buf.Size(), 1>{}( - [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i], + d0s_grid_buf[i], + d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + static_for<0, d0s_thread_buf_size, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return acc0_thread_buf(i); }, + Number<2>{}); + + unpack2(acc_element_op, dst_data_refs, src_data_refs); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i], + make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0)); + }); + } + else + { + static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); }); + } block_sync_lds(); @@ -995,15 +1281,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 } } // end gemm1 - constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); - constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); - constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); - constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); - constexpr auto c_nrepeat = 
c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); - constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); - constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); - constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + constexpr auto c_mrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto c_mwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto c_mthreadpersubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto c_nrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto c_nwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto c_nsubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto c_naccvgprs = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup, @@ -1032,29 +1318,29 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 /*******************************************************************************/ // write out to C, implement shuffle { - constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); // This API Provide All dimension (size) you need - constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp = + constexpr auto c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp = blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); - constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1); - constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2); - constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4); - constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5); - constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6); + constexpr auto MWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1); + constexpr auto MThreadPerSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4); + constexpr auto NSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5); + constexpr auto NAccVgprs = 
c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6); // LDS descriptor, shuffle and write out in MRepeat x NRepeat times - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - auto c_shuffle_block_buf = make_dynamic_buffer( + auto c1_shuffle_block_buf = make_dynamic_buffer( static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( @@ -1097,10 +1383,10 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = + auto c1_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3 const auto& // return type should be reference + { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c1_d1s_buf_refs = concat_tuple_of_reference( + tie(c1_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return d1s_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c1_d1s_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, + auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), D1sDataType{})), + Tuple, + decltype(e1_d1s_desc_refs), + decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(CGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitrary + // type + Sequence<1, CShuffleMRepeatPerShuffle * MWave * MPerWmma, 1, CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> -
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {e1_d1s_desc_refs, + idx_c1_d1s_block_begin, + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), c_element_op}; + // space filling curve for local reg & global memory // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = + constexpr auto sfc_c1_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6>, Sequence>{}; // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = + constexpr auto sfc_e1_global = SpaceFillingCurve, Sequence<0, 2, 1, 3>, Sequence<1, @@ -1174,37 +1492,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 1, CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess(); - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!"); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS block_sync_lds(); // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + sfc_c1_vgpr.GetIndexTupleOfNumber(access_id), c_thread_buf, c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, - c_shuffle_block_buf); + c1_shuffle_block_buf); // make sure it's safe to read from LDS block_sync_lds(); // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); + cde1_shuffle_block_copy_lds_to_global.Run( + e1_d1s_desc_refs, + c1_d1s_buf_refs, + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e1_grid_buf)); if constexpr(access_id < num_access - 1) { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id); + + // move on D1s + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( + e1_d1s_desc_refs, i + I1, e1_global_step); + }); + // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step); } }); } diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp index ae21d5fa87..3fef73b134 100644 --- a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp 
+++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp @@ -219,6 +219,30 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } + // + // D0 + // + static auto MakeD0GridDescriptorPair(const std::vector& d0_gs_ms_ns_lengths_vec, + const std::vector& d0_gs_ms_ns_strides_vec) + { + return MakeGridDescriptorPair(d0_gs_ms_ns_lengths_vec, + d0_gs_ms_ns_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + static auto MakeD0GridDescriptor_G_M_N(const std::vector& d0_gs_ms_ns_lengths_vec, + const std::vector& d0_gs_ms_ns_strides_vec) + { + return MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).first; + } + + static auto MakeD0GridDescriptor_M_N(const std::vector& d0_gs_ms_ns_lengths_vec, + const std::vector& d0_gs_ms_ns_strides_vec) + { + return matrix_padder.PadD0Descriptor_M_N( + MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).second); + } + // // B1 // diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 294d7e7c7d..52ca5e9126 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -14,7 +14,7 @@ namespace ck { template __host__ __device__ constexpr auto generate_tuple_for(F&& f, Sequence) { - return make_tuple(f(Number{})...); + return ck::make_tuple(f(Number{})...); } template diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp index 7fba2ad412..133e52d651 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp @@ -20,6 +20,52 @@ namespace tensor_operation { namespace device { namespace instance { +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 +void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector, + Row, + ck::Tuple, + Row, + F16, + F16, + ck::Tuple, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + CDE0ElementOp, + PassThrough, + CDE1ElementOp>>>& + instances); + +void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + std::vector, + Col, + ck::Tuple, + Row, + F16, + F16, + ck::Tuple, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + CDE0ElementOp, + PassThrough, + CDE1ElementOp>>>& + instances); +#endif // CK_ENABLE_FP16 +#endif // CK_USE_WMMA + +#ifdef CK_USE_XDL +#ifdef CK_ENABLE_FP16 void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( std::vector>>& instances); - +#endif // CK_ENABLE_FP16 +#endif // CK_USE_XDL template > op_ptrs; +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( op_ptrs); +#endif +#ifdef CK_USE_WMMA + add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); +#endif } else if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { +#ifdef CK_USE_XDL add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( op_ptrs); +#endif +#ifdef CK_USE_WMMA + 
add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + op_ptrs); +#endif } } +#endif return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt index 5d830bb2fe..6c5d149133 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -1,8 +1,11 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_batched_gemm_add_relu_gemm_add_instance device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp + + device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000..8ddff48e39 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu; +using CDE1ElementOp = ck::tensor_operation::element_wise::Add; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + //#####################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| D0DataType| B1Data| D1DataType| E1Data| AccData| CShuffle| A0| B0| CDE0| B1| CDE1| GemmSpecialization| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| MRepeat| LRepeat| NRepeat|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer| + //#####################################################| | | | | | | Type| Type| | Type| | Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | Size| MPer| NPer| KPer| NPer| KPer| | | | WMMA| WMMA| | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#####################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | | | |Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#####################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // No padding + DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, ck::Tuple, F16, ck::Tuple, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::Default, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 4, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 1, 2, S<1, 16, 1, 2>, 8>, + // Fallback with 
padding + DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple, Row, ck::Tuple, Row, F16, F16, ck::Tuple, F16, ck::Tuple, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::MNKOPadding, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, true, 1, 2, S<1, 16, 1, 2>, 1> + // clang-format on + >; + +void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector, + Row, + ck::Tuple, + Row, + F16, + F16, + ck::Tuple, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + CDE0ElementOp, + PassThrough, + CDE1ElementOp>>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp new file mode 100644 index 0000000000..052abce92d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu; +using CDE1ElementOp = ck::tensor_operation::element_wise::Add; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances = + std::tuple< + // clang-format off + //#####################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| D0DataType| B1Data| D1DataType| E1Data| AccData| CShuffle| A0| B0| CDE0| B1| CDE1| GemmSpecialization| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| MRepeat| LRepeat| NRepeat|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| 
CDE1BlockTransferClusterLengths| CDE1BlockTransfer| + //#####################################################| | | | | | | Type| Type| | Type| | Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | Size| MPer| NPer| KPer| NPer| KPer| | | | WMMA| WMMA| | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#####################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | | | |Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#####################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // No padding + DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, ck::Tuple, F16, ck::Tuple, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::Default, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 2, true, 1, 2, S<1, 16, 1, 2>, 8>, + // Fallback with padding + DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple, Col, ck::Tuple, Row, F16, F16, ck::Tuple, F16, ck::Tuple, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::MNKOPadding, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, true, 1, 2, S<1, 16, 1, 2>, 1> + // clang-format on + >; + +void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance( + std::vector, + Col, + ck::Tuple, + Row, + F16, + F16, + ck::Tuple, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + CDE0ElementOp, + PassThrough, + CDE1ElementOp>>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp b/profiler/include/profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp new file mode 100644 index 0000000000..8bb3645164 --- /dev/null +++ b/profiler/include/profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp @@ -0,0 +1,387 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <memory>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+
+namespace ck {
+namespace profiler {
+
+template <typename A0DataType,
+          typename B0DataType,
+          typename D0sDataType,
+          typename B1DataType,
+          typename D1sDataType,
+          typename E1DataType,
+          typename A0Layout,
+          typename B0Layout,
+          typename D0sLayout,
+          typename B1Layout,
+          typename D1sLayout,
+          typename E1Layout,
+          typename A0ElementOp,
+          typename B0ElementOp,
+          typename CDE0ElementOp,
+          typename B1ElementOp,
+          typename CDE1ElementOp>
+bool profile_batched_gemm_multiple_d_gemm_multiple_d_impl(
+    bool do_verification,
+    int init_method,
+    bool do_log,
+    bool time_kernel,
+    int M,
+    int N,
+    int K,
+    int O,
+    int BatchCount     = 1,
+    int StrideA0       = -1,
+    int StrideB0       = -1,
+    int StrideD0       = -1,
+    int StrideB1       = -1,
+    int StrideD1       = -1,
+    int StrideE1       = -1,
+    int BatchStrideA0  = -1,
+    int BatchStrideB0  = -1,
+    int BatchStrideD0  = -1,
+    int BatchStrideB1  = -1,
+    int BatchStrideD1  = -1,
+    int BatchStrideE1  = -1,
+    bool fail_if_no_supported_instance = false)
+{
+    using Row = tensor_layout::gemm::RowMajor;
+    using Col = tensor_layout::gemm::ColumnMajor;
+
+    using PassThrough = tensor_operation::element_wise::PassThrough;
+
+    using D0DataType = remove_cvref_t<tuple_element_t<0, D0sDataType>>;
+    using D0Layout   = remove_cvref_t<tuple_element_t<0, D0sLayout>>;
+
+    using D1DataType = remove_cvref_t<tuple_element_t<0, D1sDataType>>;
+    using D1Layout   = remove_cvref_t<tuple_element_t<0, D1sLayout>>;
+
+    // for reference
+    using RefAcc0DataType = float;
+    using RefAcc1DataType = float;
+
+    bool pass = true;
+
+    const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
+    const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
+    const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
+    const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
+    const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
+    const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
+
+    StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
+    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
+    StrideD0 = (StrideD0 < 0) ? DefaultStrideD0 : StrideD0;
+    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
+    StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
+    StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
+
+    const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
+    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
+    const int DefaultBatchStrideD0 = (ck::is_same_v<D0Layout, Col> ? N : M) * StrideD0;
+    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
+    const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
+    const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
+
+    BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
+    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
+    BatchStrideD0 = BatchStrideD0 < 0 ? DefaultBatchStrideD0 : BatchStrideD0;
+    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
+    BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
+    BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
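+
+    // Worked example: with row-major A0 and M = 256, K = 64 the defaults give
+    // StrideA0 = K = 64 and BatchStrideA0 = M * StrideA0 = 16384, i.e. each
+    // batch is laid out densely behind the previous one; column-major operands
+    // swap the roles of the two extents in both formulas.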
+
+    auto f_host_tensor_descriptor = [](std::size_t batch_count,
+                                       std::size_t row,
+                                       std::size_t col,
+                                       std::size_t stride,
+                                       std::size_t batch_stride,
+                                       auto layout) {
+        using namespace ck::literals;
+
+        if(std::is_same<decltype(layout), Row>::value)
+        {
+            return HostTensorDescriptor(
+                {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
+        }
+        else
+        {
+            return HostTensorDescriptor(
+                {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
+        }
+    };
+
+    // E_m_o = A_m_k * B0_k_n * B1_n_o
+    Tensor<A0DataType> a0_g_m_k(
+        f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
+    Tensor<B0DataType> b0_g_k_n(
+        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
+    Tensor<D0DataType> d0_g_m_n(
+        f_host_tensor_descriptor(BatchCount, M, N, StrideD0, BatchStrideD0, D0Layout{}));
+    Tensor<B1DataType> b1_g_n_o(
+        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
+    Tensor<D1DataType> d1_g_m_o(
+        f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
+    Tensor<E1DataType> e1_g_m_o_host_result(
+        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
+    Tensor<E1DataType> e1_g_m_o_device_result(
+        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
+
+    // Host verification: Output of Gemm0 is input A of Gemm1
+    Tensor<RefAcc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
+    Tensor<A0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
+    Tensor<RefAcc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
+
+    std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
+    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
+    std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl;
+    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
+    std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl;
+    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    case 1:
+        a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
+        d0_g_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 3});
+        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
+        d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
+        break;
+    default:
+        a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
+        d0_g_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
+        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
+        d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
+    }
+
+    DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSpaceSize());
+    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize());
+    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize());
+    DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
+    DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
+                                  e1_g_m_o_device_result.mDesc.GetElementSpaceSize());
+
+    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
+    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
+    d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data());
+    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
+    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
+
+    auto a0_element_op   = A0ElementOp{};
+    auto b0_element_op   = B0ElementOp{};
+    auto cde0_element_op = CDE0ElementOp{};
+    auto b1_element_op   = B1ElementOp{};
+    auto cde1_element_op = CDE1ElementOp{};
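+
+    // For the add_relu + add instance set pulled in above, cde0_element_op
+    // computes e0 = relu(c0 + d0) and cde1_element_op computes e1 = c1 + d1;
+    // the reference path below evaluates both in float (RefAcc*DataType)
+    // before converting back to the device data types.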
+
+    using DeviceOp =
+        tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
+                                                                          B0Layout,
+                                                                          D0sLayout,
+                                                                          B1Layout,
+                                                                          D1sLayout,
+                                                                          E1Layout,
+                                                                          A0DataType,
+                                                                          B0DataType,
+                                                                          D0sDataType,
+                                                                          B1DataType,
+                                                                          D1sDataType,
+                                                                          E1DataType,
+                                                                          A0ElementOp,
+                                                                          B0ElementOp,
+                                                                          CDE0ElementOp,
+                                                                          B1ElementOp,
+                                                                          CDE1ElementOp>;
+
+    // get device op instances
+    const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    if(do_verification)
+    {
+        // Ref Gemm0
+        using ReferenceGemm0Instance =
+            tensor_operation::host::ReferenceBatchedGemm<A0DataType,
+                                                         B0DataType,
+                                                         RefAcc0DataType,
+                                                         RefAcc0DataType,
+                                                         A0ElementOp,
+                                                         B0ElementOp,
+                                                         PassThrough>;
+
+        // Ref Gemm1
+        using ReferenceGemm1Instance =
+            tensor_operation::host::ReferenceBatchedGemm<A0DataType,
+                                                         B1DataType,
+                                                         RefAcc1DataType,
+                                                         RefAcc1DataType,
+                                                         PassThrough,
+                                                         B1ElementOp,
+                                                         PassThrough>;
+
+        auto ref_gemm0          = ReferenceGemm0Instance{};
+        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
+        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
+            a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
+
+        ref_gemm0_invoker.Run(ref_gemm0_argument);
+
+        // cde0_elementwise
+        // Note that we also convert from Acc0DataType to A0DataType to match what the device
+        // operation does
+        e0_g_m_n.ForEach([&](auto&, auto idx) {
+            RefAcc0DataType out;
+            cde0_element_op(out, c0_g_m_n(idx), d0_g_m_n(idx));
+
+            e0_g_m_n(idx) = ck::type_convert<A0DataType>(out);
+        });
+
+        auto ref_gemm1          = ReferenceGemm1Instance{};
+        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
+        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
+            e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
+
+        ref_gemm1_invoker.Run(ref_gemm1_argument);
+
+        // cde1_elementwise
+        e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
+            cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
+        });
+    }
+
+    std::string best_op_name;
+    float best_ave_time     = 0;
+    float best_tflops       = 0;
+    float best_gb_per_sec   = 0;
+    int instances_supported = 0;
+
+    // profile device op instances
+    for(auto& op_ptr : op_ptrs)
+    {
+        auto argument_ptr = op_ptr->MakeArgumentPointer(
+            static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
+            static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
+            std::array<const void*, 1>{d0_g_m_n_device_buf.GetDeviceBuffer()},
+            static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
+            std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
+            static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
+            M,
+            N,
+            K,
+            O,
+            BatchCount,
+            StrideA0,
+            StrideB0,
+            std::array<ck::index_t, 1>{StrideD0},
+            StrideB1,
+            std::array<ck::index_t, 1>{StrideD1},
+            StrideE1,
+            BatchStrideA0,
+            BatchStrideB0,
+            std::array<ck::index_t, 1>{BatchStrideD0},
+            BatchStrideB1,
+            std::array<ck::index_t, 1>{BatchStrideD1},
+            BatchStrideE1,
+            a0_element_op,
+            b0_element_op,
+            cde0_element_op,
+            b1_element_op,
+            cde1_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            ++instances_supported;
+
+            std::string op_name = op_ptr->GetTypeString();
+
+            float ave_time =
+                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+            std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
+            std::size_t num_btype =
+                (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
+                 sizeof(D0DataType) * M * N + sizeof(B1DataType) * N * O +
+                 sizeof(E1DataType) * M * O + sizeof(D1DataType) * M * O) *
+                BatchCount;
+
+            float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                      << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                best_op_name    = op_name;
+                best_tflops     = tflops;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
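+
+            // Metric derivation: Gemm0 performs 2*M*N*K flops and Gemm1 2*M*N*O
+            // per batch; num_btype counts one read of A0/B0/D0/B1/D1 and one
+            // write of E1. With ave_time in ms, flop/1e9/ms gives TFlops and
+            // bytes/1e6/ms gives GB/s.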
+
+            if(do_verification)
+            {
+                e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
+
+                pass = pass & ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
+
+                if(do_log)
+                {
+                    LogRangeAsType<float>(
+                        std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",")
+                        << std::endl;
+                    LogRangeAsType<float>(
+                        std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",")
+                        << std::endl;
+                }
+            }
+        }
+        else
+        {
+            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
+        }
+    }
+
+    if(instances_supported == 0)
+    {
+        if(do_log)
+        {
+            std::cout << "Warning! No supported instances found." << std::endl;
+        }
+
+        if(fail_if_no_supported_instance)
+        {
+            return false;
+        }
+    }
+
+    printf("\033[36mFound %d supported instances\033[0m\n", instances_supported);
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    return pass;
+}
+
+} // namespace profiler
+} // namespace ck
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 9fee3b5697..ef2ac098ac 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -275,6 +275,7 @@ add_subdirectory(batched_contraction)
 add_subdirectory(batched_gemm)
 add_subdirectory(batched_gemm_reduce)
 add_subdirectory(batched_gemm_gemm)
+add_subdirectory(batched_gemm_multiple_d_gemm_multiple_d)
 add_subdirectory(batched_gemm_softmax_gemm)
 add_subdirectory(batched_gemm_softmax_gemm_permute)
 add_subdirectory(batched_gemm_b_scale)
diff --git a/test/batched_gemm_multiple_d_gemm_multiple_d/CMakeLists.txt b/test/batched_gemm_multiple_d_gemm_multiple_d/CMakeLists.txt
new file mode 100644
index 0000000000..ee19891afd
--- /dev/null
+++ b/test/batched_gemm_multiple_d_gemm_multiple_d/CMakeLists.txt
@@ -0,0 +1,12 @@
+# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+
+# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
+# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch, but they will fail to link
+# the instance library if there are no instances present for the current arch.
+if (CK_USE_XDL OR CK_USE_WMMA)
+    add_gtest_executable(test_batched_gemm_add_relu_gemm_add test_batched_gemm_add_relu_gemm_add.cpp)
+    if(result EQUAL 0)
+        target_link_libraries(test_batched_gemm_add_relu_gemm_add PRIVATE utility device_batched_gemm_add_relu_gemm_add_instance)
+    endif()
+endif()
\ No newline at end of file
diff --git a/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_add_relu_gemm_add.cpp b/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_add_relu_gemm_add.cpp
new file mode 100644
index 0000000000..6f224bfbb0
--- /dev/null
+++ b/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_add_relu_gemm_add.cpp
@@ -0,0 +1,27 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "test_batched_gemm_multiple_d_gemm_multiple_d.hpp"
+
+template <typename Tuple>
+class TestBatchedGemmMultipleDGemmMultipleD
+    : public BaseTestBatchedGemmMultipleDGemmMultipleD<Tuple>
+{
+};
+
+using A0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
+using B0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
+using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
+
+using B1ElementOp   = ck::tensor_operation::element_wise::PassThrough;
+using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
+
+// clang-format off
+using KernelTypes = ::testing::Types<
+    std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>,
+    std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>
+    >;
+// clang-format on
+
+TYPED_TEST_SUITE(TestBatchedGemmMultipleDGemmMultipleD, KernelTypes);
+#include "test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc"
diff --git a/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_multiple_d_gemm_multiple_d.hpp b/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_multiple_d_gemm_multiple_d.hpp
new file mode 100644
index 0000000000..334dd69b19
--- /dev/null
+++ b/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_multiple_d_gemm_multiple_d.hpp
@@ -0,0 +1,121 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <tuple>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "ck/ck.hpp"
+#include "ck/host_utility/device_prop.hpp"
+
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp"
+
+using ck::tensor_operation::device::GemmSpecialization;
+
+template <ck::index_t N>
+using I = ck::Number<N>;
+
+using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <typename Tuple>
+struct BaseTestBatchedGemmMultipleDGemmMultipleD : public ::testing::Test
+{
+    using ADataType     = std::tuple_element_t<0, Tuple>;
+    using B0DataType    = std::tuple_element_t<1, Tuple>;
+    using D0sDataType   = std::tuple_element_t<2, Tuple>;
+    using B1DataType    = std::tuple_element_t<3, Tuple>;
+    using D1sDataType   = std::tuple_element_t<4, Tuple>;
+    using EDataType     = std::tuple_element_t<5, Tuple>;
+    using ALayout       = std::tuple_element_t<6, Tuple>;
+    using B0Layout      = std::tuple_element_t<7, Tuple>;
+    using D0sLayout     = std::tuple_element_t<8, Tuple>;
+    using B1Layout      = std::tuple_element_t<9, Tuple>;
+    using D1sLayout     = std::tuple_element_t<10, Tuple>;
+    using ELayout       = std::tuple_element_t<11, Tuple>;
+    using A0ElementOp   = std::tuple_element_t<12, Tuple>;
+    using B0ElementOp   = std::tuple_element_t<13, Tuple>;
+    using CDE0ElementOp = std::tuple_element_t<14, Tuple>;
+    using B1ElementOp   = std::tuple_element_t<15, Tuple>;
+    using CDE1ElementOp = std::tuple_element_t<16, Tuple>;
+
+    std::vector<std::vector<int>> lengths_ = {
+        {256, 256, 64, 64, 4},
+        {256, 256, 128, 128, 4},
+        {512, 512, 64, 64, 2},
+        {512, 512, 128, 128, 2},
+        {1024, 1024, 64, 64, 1},
+        {1024, 1024, 128, 128, 1},
+    };
+    bool bench_  = false;
+    bool verify_ = true;
+
+    void RunSingle(int M, int N, int K, int O, int BatchCount)
+    {
+        // WMMA instances are set up to support all the test cases;
+        // XDL instances are not.
+        bool fail_if_no_supported_instances =
+            ck::is_gfx11_supported() || ck::is_gfx12_supported();
+
+        bool pass = ck::profiler::profile_batched_gemm_multiple_d_gemm_multiple_d_impl<
+            ADataType,
+            B0DataType,
+            D0sDataType,
+            B1DataType,
+            D1sDataType,
+            EDataType,
+            ALayout,
+            B0Layout,
+            D0sLayout,
+            B1Layout,
+            D1sLayout,
+            ELayout,
+            A0ElementOp,
+            B0ElementOp,
+            CDE0ElementOp,
+            B1ElementOp,
+            CDE1ElementOp>(verify_,
+                           1,
+                           false,
+                           bench_,
+                           M,
+                           N,
+                           K,
+                           O,
+                           BatchCount,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           -1,
+                           fail_if_no_supported_instances);
+
+        EXPECT_TRUE(pass);
+    }
+
+    void Run()
+    {
+        for(auto lengths : this->lengths_)
+        {
+            int M          = lengths[0];
+            int N          = lengths[1];
+            int K          = lengths[2];
+            int O          = lengths[3];
+            int BatchCount = lengths[4];
+
+            this->RunSingle(M, N, K, O, BatchCount);
+        }
+    }
+};
diff --git a/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc b/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc
new file mode 100644
index 0000000000..0881e399d3
--- /dev/null
+++ b/test/batched_gemm_multiple_d_gemm_multiple_d/test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc
@@ -0,0 +1,88 @@
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test) { this->Run(); }
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadM)
+{
+    this->lengths_ = std::vector<std::vector<int>>{
+        {136, 128, 32, 128, 1},
+    };
+    this->Run();
+}
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadN)
+{
+    this->lengths_ = std::vector<std::vector<int>>{
+        {128, 136, 32, 128, 1},
+    };
+    this->Run();
+}
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadK)
+{
+    this->lengths_ = std::vector<std::vector<int>>{
+        {128, 128, 40, 128, 1},
+        {128, 128, 136, 128, 1},
+    };
+    this->Run();
+}
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadO)
+{
+    this->lengths_ = std::vector<std::vector<int>>{
+        {128, 128, 32, 136, 1},
+    };
+    this->Run();
+}
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddM)
+{
+    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
+    {
+        GTEST_SKIP() << "Odd sizes are only supported on WMMA instances.";
+    }
+
+    this->lengths_ = std::vector<std::vector<int>>{
+        {129, 128, 32, 128, 1},
+    };
+    this->Run();
+}
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddN)
+{
+    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
+    {
+        GTEST_SKIP() << "Odd sizes are only supported on WMMA instances.";
+    }
+
+    this->lengths_ = std::vector<std::vector<int>>{
+        {128, 129, 32, 128, 1},
+    };
+    this->Run();
+}
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddK)
+{
+    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
+    {
+        GTEST_SKIP() << "Odd sizes are only supported on WMMA instances.";
+    }
+
+    this->lengths_ = std::vector<std::vector<int>>{
+        {128, 128, 33, 128, 1},
+        {128, 128, 129, 128, 1},
+    };
+    this->Run();
+}
+
+TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddO)
+{
+    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
+    {
+        GTEST_SKIP() << "Odd sizes are only supported on WMMA instances.";
+    }
+
+    this->lengths_ = std::vector<std::vector<int>>{
+        {128, 128, 32, 129, 1},
+    };
+    this->Run();
+}
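+
+// Note: each lengths_ entry in these cases is {M, N, K, O, BatchCount}. The
+// Pad* cases use extents that are not multiples of the instance tile sizes, so
+// only the padded fallback instances can run them; the Odd* cases additionally
+// break the vector transfer widths and are only covered by the low-vector
+// WMMA fallback instances.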