diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp index 4b1106c122..1e03405536 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + #include #include diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index 776f96e8e6..89bfc180a5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -73,6 +73,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK{}; static constexpr auto I3 = Number<3>{}; + // TODO: should be exposed as Tparams. + static constexpr index_t NumGemmKPrefetchStage = 1; + static constexpr LoopScheduler LoopSched = make_default_loop_scheduler(); + static constexpr PipelineVersion PipelineVer = PipelineVersion::v2; + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, // TODO: distinguish A/B datatype @@ -85,6 +90,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer>; using Argument = typename GridwiseGemm::Argument; using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; @@ -257,7 +265,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVersion::v2>; using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; using Block2ETileMapKSplit = @@ -265,8 +268,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { +#if DEBUG_LOG + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" 
+ << std::endl; +#endif // DEBUG_LOG return false; } @@ -509,14 +516,15 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK + typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler(), + PipelineVersion PipelineVer = PipelineVersion::v1> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { static constexpr auto I0 = Number<0>{}; @@ -99,8 +102,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 static constexpr auto M01 = 1; static constexpr auto N01 = 1; + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; + using ThisThreadBlock = ThisThreadBlock; + using GridwiseGemmPipe = remove_cvref_t())>; + struct Argument : public ck::tensor_operation::device::BaseArgument { const FloatAB* p_a_grid; @@ -176,12 +186,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // prefer this to be called on host __host__ __device__ static auto CalculateMPadded(index_t M) { - return (M + MPerBlock - 1) / MPerBlock * MPerBlock; + return math::integer_least_multiple(M, MPerBlock); } __host__ __device__ static auto CalculateNPadded(index_t N) { - return (N + NPerBlock - 1) / NPerBlock * NPerBlock; + return math::integer_least_multiple(N, NPerBlock); } __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) @@ -295,8 +305,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } } - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t N, index_t MPad, index_t NPad, index_t StrideC) + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) { const auto c_grid_desc_m_n = [&]() { if constexpr(is_same::value) @@ -309,22 +318,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } }(); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) - { - return transform_tensor_descriptor(c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); } __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() @@ -383,7 +377,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) { if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG return false; + } } if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || @@ -391,40 +393,116 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) { if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! 
N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG return false; + } } if constexpr(is_same::value) { if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG return false; + } } else { if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG return false; + } } if constexpr(is_same::value) { if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG return false; + } } else { if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG return false; + } } if constexpr(is_same::value) { if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { +#if DEBUG_LOG + std::cout + << "Arg N (" << karg.N + << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG return false; + } } else { if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { +#if DEBUG_LOG + std::cout + << "Arg M (" << karg.M + << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG return false; + } + } + + const auto num_k_loop = karg.K0 / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { +#if DEBUG_LOG + std::cout << "The number of k loops (" << num_k_loop + << ") value is not supported by GridwiseGemm Pipeline." 
+ << " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; +#endif // DEBUG_LOG + return false; } return true; @@ -439,9 +517,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { - const bool has_main_k0_block_loop = K0 > K0PerBlock; - - return has_main_k0_block_loop; + const index_t num_loop = K0 / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } template @@ -490,7 +567,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 return BlockToCTileMap_3DGrid_KSplit(); } - using CGridDesc_M_N = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; template {}; -#else - auto blockwise_gemm = BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, FloatAcc, @@ -703,9 +767,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 NPerXDL, MRepeat, NRepeat, - K1>{}; - -#endif + K1, + LoopSched>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -761,7 +824,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); + } while(k0_block_data_begin < (karg.K0 - K0PerBlock)); } // tail @@ -772,13 +835,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } #else // gridwise GEMM pipeline - const auto gridwise_gemm_pipeline = - GridwiseGemmPipeline_Selector(); - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / (K0PerBlock * K1)); + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, a_b_k0_m_k1_block_desc, a_blockwise_copy, @@ -993,24 +1055,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } } - template - struct LStr - { - static std::string Get() { return ""; } - }; - - template <> - struct LStr - { - static std::string Get() { return "R"; } - }; - - template <> - struct LStr - { - static std::string Get() { return "C"; } - }; - static std::string GetTypeString() { auto str = std::stringstream(); diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index a93cb7fc84..5f5d6c9b5a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -64,6 +64,7 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = st //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###################| | | | | | | | | | | Operation| Operation| Operation| | 
Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp index 0385b0fc0c..a3d73440eb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp @@ -44,14 +44,14 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instanc DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, -// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, 
S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, -// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, -// DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, 
F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp index 5933ff61ec..dddfa2aa44 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp @@ -37,7 +37,7 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instanc //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, -// DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, 
Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, @@ -45,7 +45,7 @@ using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instanc DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, -// DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index 
4cc62509d7..ab1bce258a 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -246,9 +246,9 @@ bool profile_gemm_splitk_impl(int do_verification, } std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA - << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time - << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " - << best_op_name << std::endl; + << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << KBatch + << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 23dca244dc..9abb5e7a53 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -19,6 +19,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -43,7 +44,6 @@ bool profile_grouped_gemm_impl(int do_verification, const std::vector& StrideCs, int kbatch = 1) { - bool pass = true; auto f_host_tensor_descriptor = @@ -81,11 +81,11 @@ bool profile_grouped_gemm_impl(int do_verification, c_m_n_device_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - +#if DEBUG_LOG std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc << std::endl; - +#endif // DEBUG_LOG std::size_t num_thread = 1; switch(init_method) { @@ -191,65 +191,71 @@ bool profile_grouped_gemm_impl(int do_verification, DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + std::string gemm_name = gemm_ptr->GetTypeString(); + + if(kbatch > 1) + { + using DeviceOpSplitK = + ck::tensor_operation::device::DeviceGroupedGemmSplitK, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + if(dynamic_cast(gemm_ptr.get()) != nullptr) + { + dynamic_cast(gemm_ptr.get()) + ->SetKBatchSize(argument_ptr.get(), kbatch); + } + } if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - std::string gemm_name = gemm_ptr->GetTypeString(); - - if(kbatch > 1) - { - using DeviceOpSplitK = - ck::tensor_operation::device::DeviceGroupedGemmSplitK, - CLayout, - ADataType, - BDataType, - ck::Tuple<>, - CDataType, - AElementOp, - BElementOp, - CElementOp>; - - if(dynamic_cast(gemm_ptr.get()) != nullptr) - { - dynamic_cast(gemm_ptr.get()) - ->SetKBatchSize(argument_ptr.get(), kbatch); - } - } float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = 0, num_btype = 0; - for(std::size_t i = 0; i < gemm_descs.size(); i++) + if(time_kernel) { - flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; - num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + - 
sizeof(CDataType) * Ms[i] * Ns[i]; - } + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } - float tflops = static_cast(flop) / 1.E9 / ave_time; + float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << gemm_name << std::endl; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << std::endl; - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } } if(do_verification) { + bool instance_pass = true; for(std::size_t i = 0; i < gemm_descs.size(); i++) { c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + c_device_buf[i]->SetZero(); Tensor c_m_n_host_result( f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); @@ -274,7 +280,20 @@ bool profile_grouped_gemm_impl(int do_verification, c_element_op); ref_invoker.Run(ref_argument); - pass = pass && ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result); + if(std::is_same_v && kbatch > 1) + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_result, + "Error: Incorrect results!", + 0.06); + } + else + { + instance_pass = + instance_pass && + ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result); + } if(do_log) { @@ -289,16 +308,25 @@ bool profile_grouped_gemm_impl(int do_verification, << std::endl; } } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; } } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; } } - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + } return pass; } diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index 09bbf79389..2274854f88 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ b/test/gemm_split_k/CMakeLists.txt @@ -1,5 +1,4 @@ if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_test_executable(test_gemm_split_k gemm_split_k.cpp) - target_link_libraries(test_gemm_split_k PRIVATE utility) - target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance) + add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp) + target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) endif() diff --git a/test/gemm_split_k/gemm_split_k.cpp b/test/gemm_split_k/gemm_split_k.cpp deleted file mode 100644 index 1edb5769c6..0000000000 --- a/test/gemm_split_k/gemm_split_k.cpp +++ /dev/null @@ -1,261 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -#include "ck/library/utility/host_gemm.hpp" - -enum struct GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 - KM_KN_MN, // 2 - KM_NK_MN, // 3 -}; - -template -static bool check_out(const Tensor& ref, const Tensor& result) -{ - float max_diff = 1e-6; - - for(std::size_t i = 0; i < ref.mData.size(); ++i) - { - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); - if(max_diff < diff) - { - return false; - } - } - - return true; -} - -struct gemmArgs -{ - GemmMatrixLayout layout; - int M; - int N; - int K; - int StrideA; - int StrideB; - int StrideC; - int KBatch; -}; - -int test_gemm(const gemmArgs& args) -{ - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - - bool a_row_major, b_row_major, c_row_major; - - switch(args.layout) - { - case GemmMatrixLayout::MK_KN_MN: - a_row_major = true; - b_row_major = true; - c_row_major = true; - break; - case GemmMatrixLayout::MK_NK_MN: - a_row_major = true; - b_row_major = false; - c_row_major = true; - break; - case GemmMatrixLayout::KM_KN_MN: - a_row_major = false; - b_row_major = true; - c_row_major = true; - break; - case GemmMatrixLayout::KM_NK_MN: - a_row_major = false; - b_row_major = false; - c_row_major = true; - break; - default: printf("not supported layout"); return 1; - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) { - using namespace ck::literals; - - if(row_major) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major)); - Tensor b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major)); - Tensor c_m_n_host_result( - f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major)); - Tensor c_m_n_device_result( - f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major)); - - // init data - std::size_t num_thread = 1; - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - // set zero to c_device_buf - c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); - - host_gemm_mk_kn_mn(a_m_k, - b_k_n, - c_m_n_host_result, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}); - - DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - 
b_device_buf.ToDevice(b_k_n.mData.data()); - c_device_buf.ToDevice(c_m_n_device_result.mData.data()); - - auto test = [&](auto a_layout, auto b_layout, auto c_layout) { - bool success = false; - - using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK; - - const auto gemm_ptrs = - ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - for(auto& gemm_ptr : gemm_ptrs) - { - auto argument_ptr = - gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - args.M, - args.N, - args.K, - args.StrideA, - args.StrideB, - args.StrideC, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - args.KBatch); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get()); - - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(!check_out(c_m_n_host_result, c_m_n_device_result)) - { - success = false; - break; - } - success = true; - } - } - - return success; - }; - - bool success = false; - - if(args.layout == GemmMatrixLayout::MK_KN_MN) - { - success = test(Row{}, Row{}, Row{}); - } - else if(args.layout == GemmMatrixLayout::MK_NK_MN) - { - success = test(Row{}, Col{}, Row{}); - } - else if(args.layout == GemmMatrixLayout::KM_KN_MN) - { - success = test(Col{}, Row{}, Row{}); - } - else - { - success = test(Col{}, Col{}, Row{}); - } - - auto error_code = 0; - if(success) - { - std::cout << "test split k : Pass" << std::endl; - } - else - { - std::cout << "test split k: Fail " << std::endl; - error_code = -1; // test needs to report failure - } - return error_code; -} - -int main(int argc, char* argv[]) -{ - std::vector test_cases; - if(argc == 1) - { - test_cases = {{GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 2}, - {GemmMatrixLayout::MK_KN_MN, 1024, 1024, 1024, 1024, 1024, 1024, 8}}; - } - else if(argc == 9) - { - const auto layout = static_cast(std::stoi(argv[1])); - - const int M = std::stoi(argv[2]); - const int N = std::stoi(argv[3]); - const int K = std::stoi(argv[4]); - - const int StrideA = std::stoi(argv[5]); - const int StrideB = std::stoi(argv[6]); - const int StrideC = std::stoi(argv[7]); - const int KBatch = std::stoi(argv[8]); - test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}}; - } - else - { - printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); - printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); - printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); - printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n"); - return -1; - } - bool error = false; - for(const auto& kinder : test_cases) - { - error |= test_gemm(kinder); - } - return error ? 1 : 0; -} diff --git a/test/gemm_split_k/test_gemm_splitk.cpp b/test/gemm_split_k/test_gemm_splitk.cpp new file mode 100644 index 0000000000..9eba5bba37 --- /dev/null +++ b/test/gemm_split_k/test_gemm_splitk.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_splitk_util.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmSplitK_MK_KN + : public ck::test::TestGemmSplitK, Tuple>::type> +{ +}; + +template +class TestGemmSplitK_MK_NK + : public ck::test::TestGemmSplitK, Tuple>::type> +{ +}; + +template +class TestGemmSplitK_KM_KN + : public ck::test::TestGemmSplitK, Tuple>::type> +{ +}; + +template +class TestGemmSplitK_KM_NK + : public ck::test::TestGemmSplitK, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ADataType, BDataType, CDataType + std::tuple< F16, F16, F16>, + std::tuple< F32, F32, F32> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmSplitK_MK_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmSplitK_MK_NK, KernelTypes); +TYPED_TEST_SUITE(TestGemmSplitK_KM_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmSplitK_KM_NK, KernelTypes); + +#include "test_gemm_splitk_ut_cases.inc" diff --git a/test/gemm_split_k/test_gemm_splitk_ut_cases.inc b/test/gemm_split_k/test_gemm_splitk_ut_cases.inc new file mode 100644 index 0000000000..54b9c6c9e3 --- /dev/null +++ b/test/gemm_split_k/test_gemm_splitk_ut_cases.inc @@ -0,0 +1,217 @@ +#pragma once + +TYPED_TEST(TestGemmSplitK_MK_KN, SmallM) +{ + std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_NK, SmallM) +{ + std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_KN, SmallM) +{ + std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_NK, SmallM) +{ + std::vector Ms{0, 1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + 
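+// Note: a minimal illustrative sketch, not part of this test suite. The cases in this
+// file pick packed strides by layout: a row-major [rows x cols] matrix uses the column
+// count as its leading stride, a column-major one uses the row count. A hypothetical
+// helper expressing that rule (names are assumptions; nothing below calls it):
+//
+//   #include <cstddef>
+//   #include <type_traits>
+//   #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+//
+//   template <typename Layout>
+//   constexpr std::size_t packed_stride(std::size_t rows, std::size_t cols)
+//   {
+//       // RowMajor: elements of a row are contiguous, so the stride is the column count.
+//       // ColumnMajor: elements of a column are contiguous, so the stride is the row count.
+//       return std::is_same_v<Layout, ck::tensor_layout::gemm::RowMajor> ? cols : rows;
+//   }
+//
+// e.g. StrideA = packed_stride<ALayout>(M, K) and StrideB = packed_stride<BLayout>(K, N),
+// which matches the hand-written StrideA/StrideB values used by the cases in this file.
+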
+TYPED_TEST(TestGemmSplitK_KM_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmSplitK_KM_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} diff --git a/test/gemm_split_k/test_gemm_splitk_util.hpp b/test/gemm_split_k/test_gemm_splitk_util.hpp new file mode 100644 index 0000000000..8243747a69 --- /dev/null +++ b/test/gemm_split_k/test_gemm_splitk_util.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_splitk_impl.hpp" + +namespace ck { +namespace test { + +template +class TestGemmSplitK : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using CDataType = std::tuple_element_t<4, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int kbatch = 1) + { + bool pass = ck::profiler::profile_gemm_splitk_impl( + verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index a7619eac6e..40f634d8b3 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,5 +1,9 @@ if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp) - target_link_libraries(test_grouped_gemm_fp16 PRIVATE utility) - target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance) + add_custom_target(test_grouped_gemm) + add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) + add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) + + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface) endif() diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp deleted file mode 100644 index f20d750d36..0000000000 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
- -#include -#include - -#include "profiler/profile_grouped_gemm_impl.hpp" - -namespace { - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -bool TestGroupedGemm() -{ - - std::mt19937 gen(19391); - std::uniform_int_distribution<> distrib(1, 10); - int group_count = distrib(gen); - - // GEMM shape - std::vector gemm_descs; - std::vector p_a, p_b; - std::vector p_c; - - std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideCs; - - for(int i = 0; i < group_count; i++) - { - Ms.push_back(256 + 256 * distrib(gen)); - Ns.push_back(256 + 256 * distrib(gen)); - Ks.push_back(128 + 128 * distrib(gen)); - - StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); - StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); - StrideCs.push_back(std::is_same::value ? Ns[i] : Ms[i]); - } - - return ck::profiler::profile_grouped_gemm_impl( - true, 1, false, 1, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs); -} - -} // anonymous namespace - -int main() -{ - bool res = true; - - res = res && TestGroupedGemm(); - res = res && TestGroupedGemm(); - res = res && TestGroupedGemm(); - res = res && TestGroupedGemm(); - - std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - - return res ? 0 : 1; -} diff --git a/test/grouped_gemm/test_grouped_gemm_interface.cpp b/test/grouped_gemm/test_grouped_gemm_interface.cpp new file mode 100644 index 0000000000..ffa8840fc7 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_interface.cpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include "gtest/gtest.h" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "test_grouped_gemm_util.hpp" + +class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test +{ + protected: + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using ALayout = Row; + using BLayout = Col; + using ELayout = Row; + + static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + + template + using GGemmInstance = + ck::test::DeviceGroupedGemmSplitkInstanceWrapper; + + using DefaultGGemmInstance = GGemmInstance; +}; + +TEST_F(TestGGemmSplitKInterface_MKNKMN, TileSize) +{ + std::vector Ms{128, 256, 188, 512}; + constexpr int N = 256; + constexpr int K = 128; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // M % MPerBlock + EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ms = std::vector{256, 128, 128, 512}; + Ns = std::vector{256, 177, 128, 512}; + // N % NPerBlock + EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); +} + +TEST_F(TestGGemmSplitKInterface_MKNKMN, VectorLoadWidth) +{ + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + using PaddedGGemmInstance = GGemmInstance; + + std::vector Ms{128, 256, 256, 512}; + constexpr int N = 256; + constexpr int K = 512; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + 
std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // K % ABlockTransferSrcScalarPerVector + Ks = std::vector{256, 177, 128, 512}; + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ks = std::vector{256, 164, 128, 512}; + // K % BBlockTransferSrcScalarPerVector + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ks = std::vector(4, 128); + Ns = std::vector{256, 127, 128, 512}; + // N % CBlockTransferScalarPerVector_NWaveNPerXDL + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); +} + +TEST_F(TestGGemmSplitKInterface_MKNKMN, KLoops) +{ + std::vector Ms{128, 256, 256, 512}; + constexpr int N = 256; + constexpr int K = 128; + constexpr int kbatch = 4; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // kloops % 2 + Ks = std::vector{256, 512, 320, 768}; + EXPECT_FALSE( + DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch)); + + // Not all gemms have same value for main_k0_block_loop! + Ks = std::vector{256, 512, 512, 512}; + EXPECT_THROW(DefaultGGemmInstance{}.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch), + std::runtime_error); +} + +class TestGGemmSplitKInterface_KMKNNM : public ::testing::Test +{ + protected: + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using ALayout = Col; + using BLayout = Row; + using ELayout = Col; + + static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + + template + using GGemmInstance = + ck::test::DeviceGroupedGemmSplitkInstanceWrapper; + + using DefaultGGemmInstance = GGemmInstance; +}; + +TEST_F(TestGGemmSplitKInterface_KMKNNM, TileSize) +{ + std::vector Ms{128, 256, 188, 512}; + constexpr int N = 256; + constexpr int K = 128; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // M % MPerBlock + EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ms = std::vector{128, 256, 256, 512}; + Ns = std::vector{256, 177, 128, 512}; + // N % NPerBlock + EXPECT_FALSE(DefaultGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); +} + +TEST_F(TestGGemmSplitKInterface_KMKNNM, VectorLoadWidth) +{ + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + using PaddedGGemmInstance = GGemmInstance; + + std::vector Ms{128, 256, 256, 512}; + constexpr int N = 256; + constexpr int K = 512; + + std::vector Ns(Ms.size(), N); + std::vector Ks(Ms.size(), K); + std::vector StrideAs(Ms.size(), K); + std::vector StrideBs(Ms.size(), K); + std::vector StrideCs(Ms.size(), N); + + // M % ABlockTransferSrcScalarPerVector + Ms = std::vector{256, 177, 128, 512}; + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ms = std::vector{128, 256, 256, 512}; + Ns = std::vector{256, 164, 128, 512}; + // N % BBlockTransferSrcScalarPerVector + EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); + + Ns = std::vector{128, 256, 256, 512}; + Ms = std::vector{256, 130, 128, 512}; + // M % CBlockTransferScalarPerVector_NWaveNPerXDL + 
EXPECT_FALSE(PaddedGGemmInstance{}.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs)); +} diff --git a/test/grouped_gemm/test_grouped_gemm_splitk.cpp b/test/grouped_gemm/test_grouped_gemm_splitk.cpp new file mode 100644 index 0000000000..d9282fa924 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_splitk.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +#include "gtest/gtest.h" +#include "test_grouped_gemm_util.hpp" + +using F16 = ck::half_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using RRR_F16_F16_F16 = ck::test::TestGroupedGemm>; +using RCR_F16_F16_F16 = ck::test::TestGroupedGemm>; + +using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm>; +using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm>; + +const std::vector KBATCH{1, 2, 3, 5, 8}; + +INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN, + RRR_F16_F16_F16_LargeK, + testing::Values(32, 64)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK, + RCR_F16_F16_F16_LargeK, + testing::Values(32, 64)); + +#include "test_grouped_gemm_ut_cases.inc" diff --git a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc new file mode 100644 index 0000000000..d94d140d97 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc @@ -0,0 +1,180 @@ +#pragma once + +TEST_P(RRR_F16_F16_F16, TinyCases) +{ + const std::vector Ms{0, 1}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16, SmallCases) +{ + const std::vector Ms{2, 1, 3, 4, 5, 0}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16, MidCases) +{ + const std::vector Ms{167, 183, 177, 153, 139, 204}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16, Regular) +{ + const std::vector Ms{64, 128, 256}; + constexpr int N = 768; + constexpr int K = 320; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_F16_F16_F16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + 
+
+TEST_P(RRR_F16_F16_F16, MNKPadded)
+{
+    const std::vector<int> Ms{127, 150, 188, 210};
+    constexpr int N = 136;
+    constexpr int K = 280;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), N);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RCR_F16_F16_F16, TinyCases)
+{
+    const std::vector<int> Ms{0, 1};
+    constexpr int N = 768;
+    constexpr int K = 544;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), K);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RCR_F16_F16_F16, SmallCases)
+{
+    const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
+    constexpr int N = 768;
+    constexpr int K = 544;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), K);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RCR_F16_F16_F16, MidCases)
+{
+    const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
+    constexpr int N = 768;
+    constexpr int K = 544;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), K);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RCR_F16_F16_F16, Regular)
+{
+    const std::vector<int> Ms{32, 64, 128, 256};
+    constexpr int N = 768;
+    constexpr int K = 320;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), K);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RCR_F16_F16_F16, MNKPadded)
+{
+    const std::vector<int> Ms{127, 150, 188, 210};
+    constexpr int N = 136;
+    constexpr int K = 280;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), K);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
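+
+// K = 4096 split into 32 or 64 batches exercises the split-K accumulation over a large number
+// of partial results per output tile.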
+
+TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch)
+{
+    const std::vector<int> Ms{188, 210};
+    constexpr int N = 768;
+    constexpr int K = 4096;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), N);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
+
+TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch)
+{
+    const std::vector<int> Ms{188, 210};
+    constexpr int N = 768;
+    constexpr int K = 4096;
+
+    const std::vector<int> Ns(Ms.size(), N);
+    const std::vector<int> Ks(Ms.size(), K);
+    const std::vector<int> StrideAs(Ms.size(), K);
+    const std::vector<int> StrideBs(Ms.size(), K);
+    const std::vector<int> StrideCs(Ms.size(), N);
+
+    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+}
diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp
new file mode 100644
index 0000000000..b61118b512
--- /dev/null
+++ b/test/grouped_gemm/test_grouped_gemm_util.hpp
@@ -0,0 +1,249 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <array>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <vector>
+
+#include "ck/ck.hpp"
+#include "ck/stream_config.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/sequence.hpp"
+#include "ck/utility/tuple.hpp"
+#include "ck/utility/number.hpp"
+#include "profiler/profile_grouped_gemm_impl.hpp"
+
+namespace ck {
+namespace test {
+
+template <typename Range>
+std::string serialize_range(const Range& range)
+{
+    std::stringstream ss;
+    for(auto& r : range)
+    {
+        ss << r << ", ";
+    }
+    std::string str = ss.str();
+    return str.empty() ? str : std::string(str.begin(), str.end() - 2);
+}
+
+template <typename Tuple>
+class TestGroupedGemm : public testing::TestWithParam<int>
+{
+    protected:
+    using ALayout   = std::tuple_element_t<0, Tuple>;
+    using BLayout   = std::tuple_element_t<1, Tuple>;
+    using ELayout   = std::tuple_element_t<2, Tuple>;
+    using ADataType = std::tuple_element_t<3, Tuple>;
+    using BDataType = std::tuple_element_t<4, Tuple>;
+    using EDataType = std::tuple_element_t<5, Tuple>;
+
+    public:
+    static constexpr bool verify_     = true;
+    static constexpr int init_method_ = 1; // integer value initialization
+    static constexpr bool log_        = false;
+    static constexpr bool bench_      = false; // measure kernel performance
+
+    void SetUp() override {}
+
+    void Run(const std::vector<int>& Ms,
+             const std::vector<int>& Ns,
+             const std::vector<int>& Ks,
+             const std::vector<int>& StrideAs,
+             const std::vector<int>& StrideBs,
+             const std::vector<int>& StrideCs,
+             int kbatch = 1)
+    {
+        bool pass = ck::profiler::profile_grouped_gemm_impl(
+            verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch);
+        EXPECT_TRUE(pass);
+    }
+};
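+
+// Thin wrapper around a single DeviceGroupedGemmXdlSplitKCShuffle instance. It only builds the
+// argument from the given problem sizes (device pointers stay null), so IsSupported() probes
+// IsSupportedArgument() and Run() exercises argument/workspace setup and the launch path rather
+// than numerical correctness.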
+
+template
+struct DeviceGroupedGemmSplitkInstanceWrapper
+{
+    using F16         = half_t;
+    using F32         = float;
+    using Row         = ck::tensor_layout::gemm::RowMajor;
+    using Col         = ck::tensor_layout::gemm::ColumnMajor;
+    using PassThrough = tensor_operation::element_wise::PassThrough;
+
+    using EmptyTuple = ck::Tuple<>;
+
+    template <index_t... Is>
+    using S = ck::Sequence<Is...>;
+
+    template <index_t N>
+    using I = ck::Number<N>;
+
+    using ABlockTransferThreadClusterArrangeOrder =
+        std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
+    using ABlockTransferSrcAccessOrder =
+        std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
+    using ABlockTransferSrcVectorDim =
+        std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
+    using ABlockTransferDstScalarPerVector_K1 =
+        std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
+    using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
+
+    using BBlockTransferThreadClusterArrangeOrder =
+        std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
+    using BBlockTransferSrcAccessOrder =
+        std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
+    using BBlockTransferSrcVectorDim =
+        std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
+    using BBlockTransferDstScalarPerVector_K1 =
+        std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<8>>;
+    using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<BLayout, Row>, I<0>, I<1>>;
+
+    using DeviceGroupedGemmSplitKInstance =
+        tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
+            ALayout,
+            BLayout,
+            EmptyTuple,
+            ELayout,
+            F16,
+            F16,
+            F32,
+            F16,
+            EmptyTuple,
+            F16,
+            PassThrough,
+            PassThrough,
+            PassThrough,
+            GemmSpec,
+            1,
+            128,
+            128,
+            128,
+            KPerBlock,
+            K1,
+            K1,
+            32,
+            32,
+            4,
+            2,
+            S<1, 4, 32, 1>,
+            ABlockTransferThreadClusterArrangeOrder,
+            ABlockTransferSrcAccessOrder,
+            ABlockTransferSrcVectorDim::value,
+            ABlockTransferSrcScalarPerVector,
+            ABlockTransferDstScalarPerVector_K1::value,
+            ABlockLdsAddExtraM::value,
+            S<1, 4, 32, 1>,
+            BBlockTransferThreadClusterArrangeOrder,
+            BBlockTransferSrcAccessOrder,
+            BBlockTransferSrcVectorDim::value,
+            BBlockTransferSrcScalarPerVector,
+            BBlockTransferDstScalarPerVector_K1::value,
+            BBlockLdsAddExtraM::value,
+            1,
+            1,
+            S<1, 16, 1, 8>,
+            CDEBlockTransferScalarPerVector_NPerBlock>;
+
+    bool IsSupported(const std::vector<int>& Ms,
+                     const std::vector<int>& Ns,
+                     const std::vector<int>& Ks,
+                     const std::vector<int>& StrideAs,
+                     const std::vector<int>& StrideBs,
+                     const std::vector<int>& StrideCs,
+                     int kbatch = 1) const
+    {
+        std::size_t n_groups = Ms.size();
+        EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups &&
+                    StrideAs.size() == n_groups && StrideBs.size() == n_groups &&
+                    StrideCs.size() == n_groups)
+            << "The number of groups is not consistent!";
+
+        std::vector<tensor_operation::device::GemmDesc> gemm_descs;
+
+        for(std::size_t i = 0; i < n_groups; ++i)
+        {
+            gemm_descs.push_back(tensor_operation::device::GemmDesc{
+                Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
+        }
+
+        std::vector<const void*> p_As(n_groups, nullptr);
+        std::vector<const void*> p_Bs(n_groups, nullptr);
+        std::vector<void*> p_Cs(n_groups, nullptr);
+        auto p_Ds = std::vector<std::array<const void*, 0>>{};
+
+        auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
+        auto argument       = ggemm_instance.MakeArgument(
+            p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
+        if(kbatch > 1)
+        {
+            ggemm_instance.SetKBatchSize(argument, kbatch);
+        }
+
+        return ggemm_instance.IsSupportedArgument(argument);
+    }
+
+    float Run(const std::vector<int>& Ms,
+              const std::vector<int>& Ns,
+              const std::vector<int>& Ks,
+              const std::vector<int>& StrideAs,
+              const std::vector<int>& StrideBs,
+              const std::vector<int>& StrideCs,
+              int kbatch = 1) const
+    {
+        std::size_t n_groups = Ms.size();
+        EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups &&
+                    StrideAs.size() == n_groups && StrideBs.size() == n_groups &&
+                    StrideCs.size() == n_groups)
+            << "The number of groups is not consistent!";
+
+        std::vector<tensor_operation::device::GemmDesc> gemm_descs;
+
+        for(std::size_t i = 0; i < n_groups; ++i)
+        {
+            gemm_descs.push_back(tensor_operation::device::GemmDesc{
+                Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
+        }
+
+        std::vector<const void*> p_As(n_groups, nullptr);
+        std::vector<const void*> p_Bs(n_groups, nullptr);
+        std::vector<void*> p_Cs(n_groups, nullptr);
+        auto p_Ds = std::vector<std::array<const void*, 0>>{};
+
+        auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
+        auto argument       = ggemm_instance.MakeArgument(
+            p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
+        if(kbatch > 1)
+        {
+            ggemm_instance.SetKBatchSize(argument, kbatch);
+        }
+
+        EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
+        auto invoker = ggemm_instance.MakeInvoker();
+        DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
+        ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
+        return invoker.Run(argument, StreamConfig{nullptr, false});
+    }
+};
+
+} // namespace test
+} // namespace ck