diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp index 06b7c7d324..1ef4c06fdc 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -726,11 +726,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce #include #include "device.hpp" @@ -660,13 +658,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r2< GridwiseGemm, @@ -919,4 +916,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 7f666b32ea..b508606a75 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -640,13 +640,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r1< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index f334cb9c8d..3574f7667e 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -478,13 +478,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v2r3< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp index 9182b0ef1f..ff267c6cdf 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1296,11 +1296,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); - const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0); + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v2r3< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index b13466274f..ac62448386 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -775,13 +775,12 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v2r3< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 8c02ddd3fd..915424bc37 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -530,11 +530,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce -#include -#include "device.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v3r1.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename CShuffleDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - ck::index_t BlockSize, - ck::index_t MPerBlock, - ck::index_t NPerBlock, - ck::index_t KPerBlock, - ck::index_t AK1, - ck::index_t BK1, - ck::index_t MPerXDL, - ck::index_t NPerXDL, - ck::index_t MXdlPerWave, - ck::index_t NXdlPerWave, - typename ABlockTransferThreadClusterLengths_K0_M_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - ck::index_t ABlockTransferSrcVectorDim, - ck::index_t ABlockTransferSrcScalarPerVector, - ck::index_t ABlockTransferDstScalarPerVector_K1, - bool ABlockLdsAddExtraM, - typename BBlockTransferThreadClusterLengths_K0_N_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - ck::index_t BBlockTransferSrcVectorDim, - ck::index_t BBlockTransferSrcScalarPerVector, - ck::index_t BBlockTransferDstScalarPerVector_K1, - bool BBlockLdsAddExtraN, - index_t CShuffleMXdlPerWavePerShuffle, - index_t CShuffleNXdlPerWavePerShuffle, - typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> -struct DeviceGemmXdl_C_Shuffle - : public DeviceGemm -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - assert(K % AK1 == 0); - - const index_t K0 = K / AK1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_k0_m_k1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, AK1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_k0_m_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % BK1 == 0); - - const index_t K0 = K / BK1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_k0_n_k1 = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, BK1)), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_k0_n_k1; - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, - BBlockLdsAddExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, - CBlockTransferScalarPerVector_NWaveNPerXdl, - NumPrefetch>; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} - { - a_grid_desc_k0_m_k1_ = - DeviceGemmXdl_C_Shuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = - DeviceGemmXdl_C_Shuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle::MakeCGridDescriptor_M_N(M, N, StrideC); - - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) - { - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = - GridwiseGemm:: - MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( - c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceGemmXdl_C_Shuffle::Argument; - - float Run(const Argument& arg, int nrepeat = 1) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); - } - - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); - - float ave_time = 0; - - if(has_main_k_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v3r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v3r1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm:: - CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_b, - p_c, - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t /* KBatch */ = 1) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_b), - static_cast(p_c), - M, - N, - K, - StrideA, - StrideB, - StrideC, - 1, - 1, - a_element_op, - b_element_op, - c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmXdl_C_Shuffle" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp index 9cdb8009fb..70b1b0fe22 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp @@ -1,6 +1,4 @@ -#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP -#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP - +#pragma once #include #include #include "device.hpp" @@ -291,18 +289,17 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d arg.N01_)) { throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting"); } const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r2< GridwiseGemm, @@ -505,4 +502,3 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp index cf9804ad4b..c65ff6022a 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp @@ -303,13 +303,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r2< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp index 12257859c7..4a478c995d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp @@ -345,13 +345,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); float ave_time = 0; - if(has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdlops_v3r3< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index c4ee7b9291..440519537e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -465,11 +465,9 @@ struct DeviceGemm_Xdl_CShuffle const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); - float ave_time = 0; - if(has_main_k_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdl_cshuffle_v1< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp index fdf7dc598f..7875be9dd7 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp @@ -467,11 +467,9 @@ struct DeviceGemm_Xdl_CShuffle_v2 const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); - float ave_time = 0; - if(has_main_k_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdl_cshuffle_v2< GridwiseGemm, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index e0f70aeddb..53e3027d67 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -450,59 +450,53 @@ struct DeviceGroupedGemmXdl float Run(const Argument& arg, int nrepeat = 1) { - StaticallyIndexedArray gemm_desc_kernel_arg_arg; + StaticallyIndexedArray gemm_desc_kernel_args; - bool has_main_k0_block_loop = true; + bool has_main_k_block_loop = true; static_for<0, MaxGroupCount, 1>{}([&](auto i) { if(i < arg.gemm_desc_kernel_arg_.size()) { - gemm_desc_kernel_arg_arg(i) = arg.gemm_desc_kernel_arg_[i]; + gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i]; std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" - << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " - << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1) - << ", " - << gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2) - << "}"; + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; std::cout << ", arg.b_grid_desc_k0_n_k1_{" - << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " - << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1) - << ", " - << gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2) - << "}"; + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; std::cout << ", arg.c_grid_desc_m_n_{ " - << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", " - << gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}" + << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", " + << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - if(!GridwiseGemm::CheckValidity( - gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_, - gemm_desc_kernel_arg_arg[i].b_grid_desc_k0_n_k1_, - gemm_desc_kernel_arg_arg[i].c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + if(!GridwiseGemm::CheckValidity(gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_, + gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_, + gemm_desc_kernel_args[i].c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const auto K0 = gemm_desc_kernel_arg_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0); + const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) * + gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2); - if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop) + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) { - throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop"); + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); } } }); float ave_time = 0; - if(has_main_k0_block_loop) + if(has_main_k_block_loop) { const auto kernel = kernel_grouped_gemm_xdlops_v2r3 + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -51,22 +51,22 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_d0_grid, - p_d1_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - d0_reduce_op, - d1_reduce_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - d_grid_desc_mblock_mperblock, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_d0_grid, + p_d1_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + d0_reduce_op, + d1_reduce_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + d_grid_desc_mblock_mperblock, + block_2_ctile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -154,6 +154,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static constexpr auto AK1 = Number{}; static constexpr auto BK1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy @@ -237,21 +241,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) return false; - // check NumGemmKPrefetchStage - if constexpr(NumGemmKPrefetchStage == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumGemmKPrefetchStage == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K / KPerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -271,12 +264,11 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1; + const index_t num_loop = K / KPerBlock; - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } __host__ __device__ static constexpr auto @@ -362,7 +354,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -485,7 +477,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumGemmKPrefetchStage, - HasMainK0BlockLoop>{}; - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - num_k_block_main_loop); + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); // shuffle C and write out { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index e9162f6e8a..7c237ae086 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -120,6 +120,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -262,7 +264,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 }(); using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -476,7 +478,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // sanity check auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -48,7 +48,7 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -119,7 +119,7 @@ template < index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> + index_t NumGemmKPrefetchStage = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 { static constexpr auto I0 = Number<0>{}; @@ -134,6 +134,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { constexpr auto max_lds_align = K1; @@ -252,21 +256,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumPrefetch == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K0 / K0PerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -296,12 +289,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; + const index_t num_loop = K / (K0PerBlock * K1); - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } template @@ -379,7 +371,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -455,7 +447,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 1, AThreadTransferSrcResetCoordinateAfterRun, true, - NumPrefetch>( + NumGemmKPrefetchStage>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -486,7 +478,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 1, BThreadTransferSrcResetCoordinateAfterRun, true, - NumPrefetch>( + NumGemmKPrefetchStage>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -503,7 +495,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 // sanity check auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumPrefetch, - HasMainK0BlockLoop>{}; - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, - a_block_desc_k0_m_k1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0_n_k1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - K0BlockMainLoop); + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index fa6f1d1f6b..940ed9d40f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -25,7 +25,7 @@ template + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -52,7 +52,7 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_a_grid, p_b_grid, p_c_grid, @@ -128,7 +128,7 @@ template < index_t CShuffleNXdlPerWavePerShuffle, typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, index_t CBlockTransferScalarPerVector_NWaveNPerXdl, - index_t NumPrefetch = 1> + index_t NumGemmKPrefetchStage = 1> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 { static constexpr auto I0 = Number<0>{}; @@ -143,6 +143,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 // K1 should be Number<...> static constexpr auto K1 = Number{}; + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { constexpr auto max_lds_align = K1; @@ -261,21 +265,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check NumPrefetch - if constexpr(NumPrefetch == 1) - { - // 1-stage prefetch always supported - } - else if constexpr(NumPrefetch == 2) - { - // 2-stage prefetch currently only support even number of K0 loop - // TODO: add support for odd number of K0 loop - if(!((K0 / K0PerBlock) % 2 == 0)) - { - return false; - } - } - else + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -305,12 +298,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 return grid_size; } - // TODO move this function into GEMM-pipeline class - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1; + const index_t num_loop = K / (K0PerBlock * K1); - return has_main_k0_block_loop; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } template @@ -393,7 +385,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 using DefaultBlock2CTileMap = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -522,7 +514,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 // sanity check auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - remove_cvref_t, - NumPrefetch, - HasMainK0BlockLoop>{}; - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1, - a_block_desc_k0_m_k1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0_n_k1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - K0BlockMainLoop); + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); // shuffle C and write out { @@ -623,17 +597,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - make_tuple( - make_freeze_transform(I0), // freeze mblock - make_pass_through_transform( - Number{}), // M0 (MXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl - make_freeze_transform(I0), // freeze nblock - make_pass_through_transform( - Number{}), // N0 (NXdlPerWave) per shuffle - make_unmerge_transform( - make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp index 791d0c2810..de97b60a62 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -20,26 +20,28 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off - //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Num| - //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Prefetch| - //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | - //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2>, - DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, F16, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8, 2> + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >;