diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 7ea0b8cc83..6873b3b436 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -1,36 +1,26 @@ #pragma once #include "ck/host/headers.hpp" -#include "ck/host/stringutils.hpp" #include #include #include #include #include -#include #include #include #include #include -// NOLINTNEXTLINE -const char* const ck_content_wrapper = R"__ck__( -${content} -)__ck__"; - -template -inline std::string content_wrapper(P p) -{ - return ck::host::InterpolateString(ck_content_wrapper, - {{"content", std::string{p.data(), p.size()}}}); -} - inline std::vector create_headers_for_test() { auto ck_headers = ck::host::GetHeaders(); std::vector result; std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) { - return rtc::src_file{p.first, content_wrapper(p.second)}; + std::string content; + content.reserve(p.second.size() + 1); + content.push_back(' '); // We need a whitespace before the content for hipRTC to work + content.append(p.second.data(), p.second.size()); + return rtc::src_file{p.first, std::move(content)}; }); return result; } diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 144d32b7cd..c35c11b670 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -205,7 +205,7 @@ struct hiprtc_program } else { - headers.push_back(std::string(src.content.begin(), src.content.end())); + headers.push_back(std::move(src.content)); include_names.push_back(std::move(src.path)); } } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 12adbcd0fd..16c4d53d57 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -615,96 +615,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } - static constexpr bool - IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) - { - // check vector load/store - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - - // check vector load of A - if constexpr(is_same_v) - { - if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v) - { - if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector load of B - if constexpr(is_same_v) - { - if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v) - { - if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector load of B1 - if constexpr(is_same_v) - { - if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v) - { - if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector load of C - if constexpr(is_same_v) - { - if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - return false; - } - } - else if constexpr(is_same_v) - { - if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - return false; - } - } - else - { - return false; - } - - return true; - } - #ifndef __HIPCC_RTC__ static bool IsSupportedArgument(const Argument& arg) { @@ -861,268 +771,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return str.str(); } #endif - - template - struct Descriptor - { - template - static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc) - { - const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc); - - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK0 = K / AK1; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - template - static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc) - { - const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc); - - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK0 = K / BK1; - - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - template - static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc) - { - const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc); - - const auto N = b1_grid_desc_n_k.GetLength(I0); - const auto K = b1_grid_desc_n_k.GetLength(I1); - - const auto B1K0 = K / B1K1; - - return transform_tensor_descriptor( - b1_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - template - static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc) - { - return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc); - } - - using AGridDesc_AK0_M_AK1 = - remove_cvref_t; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using B1GridDesc_BK0_N_BK1 = - remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - - // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< - ADataType, // TODO: distinguish A/B datatype - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - AccElementwiseOperation, - B1ElementwiseOperation, - CElementwiseOperation, - InMemoryDataOperationEnum::Set, - AGridDesc_AK0_M_AK1, - BGridDesc_BK0_N_BK1, - B1GridDesc_BK0_N_BK1, - CGridDesc_M_N, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - Gemm1NPerBlock, - Gemm1KPerBlock, - AK1, - BK1, - B1K1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - Gemm1NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - true, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - true, - BBlockLdsExtraN, - B1BlockTransferThreadClusterLengths_BK0_N_BK1, - B1BlockTransferThreadClusterArrangeOrder, - B1BlockTransferSrcAccessOrder, - B1BlockTransferSrcVectorDim, - B1BlockTransferSrcScalarPerVector, - B1BlockTransferDstScalarPerVector_BK1, - false, - B1BlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - LoopSched, - matrix_padder.PadN, - MaskOutUpperTriangle>; - - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; - B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; - CGridDesc_M_N c_grid_desc_m_n; - C0MatrixMask c0_matrix_mask; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_descriptor_mblock_mperblock_nblock_nperblock; - - // element-wise op - AElementwiseOperation a_element_op; - BElementwiseOperation b_element_op; - B1ElementwiseOperation b1_element_op; - CElementwiseOperation c_element_op; - - bool has_main_k_block_loop = true; - bool is_valid = false; - - constexpr Descriptor(ADesc a, - BDesc b, - B1Desc b1, - CDesc c, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - B1ElementwiseOperation b1_element_op_, - CElementwiseOperation c_element_op_) - : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)}, - b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, - b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, - c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, - block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, - c_grid_descriptor_mblock_mperblock_nblock_nperblock{ - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n)}, - has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( - a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, - c0_matrix_mask{c.GetLength(I1)}, - a_element_op{a_element_op_}, - b_element_op{b_element_op_}, - b1_element_op{b1_element_op_}, - c_element_op{c_element_op_}, - is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c_grid_desc_m_n, - block_2_ctile_map) and - IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), - b_grid_desc_bk0_n_bk1.GetLength(I1), - a_grid_desc_ak0_m_ak1.GetLength(I0) * - a_grid_desc_ak0_m_ak1.GetLength(I2), - b1_grid_desc_bk0_n_bk1.GetLength(I1))} - { - } - - constexpr bool IsValid() const { return is_valid; } - }; - - template - static constexpr auto - make_descriptor(ADesc a, - BDesc b, - B1Desc b1, - CDesc c, - AElementwiseOperation a_element_op = AElementwiseOperation{}, - BElementwiseOperation b_element_op = BElementwiseOperation{}, - B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{}, - CElementwiseOperation c_element_op = CElementwiseOperation{}) - { - return Descriptor( - a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op); - } - - template - __device__ static void Run(const Desc& desc, - const float scale, - const ADataType* __restrict__ p_a_grid, - const ADataType* __restrict__ p_b_grid, - const ADataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid) - { -#ifndef __HIPCC_RTC__ - assert(desc.is_valid); -#endif - __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; - AccElementwiseOperation acc_element_op{scale}; - - if(desc.has_main_k_block_loop) - { - Desc::GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_b1_grid, - p_c_grid, - p_shared_block, - desc.a_element_op, - desc.b_element_op, - acc_element_op, - desc.b1_element_op, - desc.c_element_op, - desc.a_grid_desc_ak0_m_ak1, - desc.b_grid_desc_bk0_n_bk1, - desc.b1_grid_desc_bk0_n_bk1, - desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, - desc.block_2_ctile_map, - desc.c0_matrix_mask); - } - else - { - Desc::GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_b1_grid, - p_c_grid, - p_shared_block, - desc.a_element_op, - desc.b_element_op, - acc_element_op, - desc.b1_element_op, - desc.c_element_op, - desc.a_grid_desc_ak0_m_ak1, - desc.b_grid_desc_bk0_n_bk1, - desc.b1_grid_desc_bk0_n_bk1, - desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, - desc.block_2_ctile_map, - desc.c0_matrix_mask); - } - } }; } // namespace device