From e10cbe9ee247fa2b2cfe6ed86fde19eb63f9c6f0 Mon Sep 17 00:00:00 2001 From: Alan Turner Date: Thu, 21 Sep 2023 16:40:41 +0000 Subject: [PATCH] Add constexpr IsSupported --- ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 169 ++++++++++++------ 1 file changed, 111 insertions(+), 58 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 8f16838149..e0a878a903 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -611,6 +611,99 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } + static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) + { + // check vector load/store + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + // FIXME: not rigorous + if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B + if constexpr(is_same_v) + { + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + // FIXME: not rigorous + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B1 + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + // FIXME: not rigorous + if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of C + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + // FIXME: not rigorous + if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + + return true; + } + static bool IsSupportedArgument(const Argument& arg) { if(!ck::is_xdl_supported()) @@ -625,52 +718,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; - // Check scalar per vector requirement - const auto a_extent_lowest = - is_same_v ? KRaw : MRaw; - const auto b_extent_lowest = - is_same_v ? NRaw : KRaw; - const auto b1_extent_lowest = - is_same_v ? Gemm1NRaw : NRaw; - const auto c_extent_lowest = - is_same_v ? Gemm1NRaw : MRaw; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - return false; - } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); - } - - static constexpr bool IsSupported(index_t MRaw, index_t NRaw, index_t KRaw, index_t Gemm1NRaw) - { - // Check scalar per vector requirement - const auto a_extent_lowest = - is_same_v ? KRaw : MRaw; - const auto b_extent_lowest = - is_same_v ? NRaw : KRaw; - const auto b1_extent_lowest = - is_same_v ? Gemm1NRaw : NRaw; - const auto c_extent_lowest = - is_same_v ? Gemm1NRaw : MRaw; - - if (!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - return false; - } - - return true; + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); } // polymorphic @@ -861,7 +914,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle remove_cvref_t; // GridwiseGemm - using GridwiseGemmSpec = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, @@ -928,8 +981,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; CGridDesc_M_N c_grid_desc_m_n; C0MatrixMask c0_matrix_mask; - typename GridwiseGemmSpec::DefaultBlock2CTileMap block_2_ctile_map; - typename GridwiseGemmSpec::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_descriptor_mblock_mperblock_nblock_nperblock; // element-wise op AElementwiseOperation a_element_op; @@ -952,23 +1005,27 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, - block_2_ctile_map{GridwiseGemmSpec::MakeDefaultBlock2CTileMap( + block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap( c_grid_desc_m_n)}, c_grid_descriptor_mblock_mperblock_nblock_nperblock{ - GridwiseGemmSpec::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)}, - has_main_k_block_loop{GridwiseGemmSpec::CalculateHasMainKBlockLoop( + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, c0_matrix_mask{c.GetLength(I1)}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, b1_element_op{b1_element_op_}, c_element_op{c_element_op_}, - is_valid{GridwiseGemmSpec::CheckValidity( + is_valid{GridwiseGemm::CheckValidity( a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1, c_grid_desc_m_n, - block_2_ctile_map)} + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1))} { } @@ -1001,17 +1058,13 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle const ADataType* __restrict__ p_b1_grid, CDataType* __restrict__ p_c_grid) { - assert(desc.is_valid and - IsSupported(desc.a_grid_desc_ak0_m_ak1.GetLength(I1), - desc.b_grid_desc_bk0_n_bk1.GetLength(I1), - desc.a_grid_desc_ak0_m_ak1.GetLength(I0) * desc.a_grid_desc_ak0_m_ak1.GetLength(I2), - desc.b1_grid_desc_bk0_n_bk1.GetLength(I1))); - __shared__ char p_shared_block[Desc::GridwiseGemmSpec::GetSharedMemoryNumberOfByte()]; + assert(desc.is_valid); + __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; AccElementwiseOperation acc_element_op{scale}; if(desc.has_main_k_block_loop) { - Desc::GridwiseGemmSpec::template Run(p_a_grid, + Desc::GridwiseGemm::template Run(p_a_grid, p_b_grid, p_b1_grid, p_c_grid, @@ -1030,7 +1083,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle } else { - Desc::GridwiseGemmSpec::template Run(p_a_grid, + Desc::GridwiseGemm::template Run(p_a_grid, p_b_grid, p_b1_grid, p_c_grid,