From 25d05d36c4587eef2ae8d395cfc7fc142772d363 Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 3 Feb 2022 12:47:27 +0800 Subject: [PATCH] add split-k GEMM (#59) * add DeviceGemmSplitKXdl * add file device_gemm_splitk_xdl.hpp * set c matrix zero * using atomic * add all tuning parameters to f32 mkkn * grid size change to 720 * add tuning parameters for NT * add tuning parameters for TN * add tuning parameters for TT * add m=96 tuning parameter * add lost config * add element wise operation * fixed MPerBlock=96 * remove macro for split-k switch * add test * add new line at the end of device_gemm_xdl_instance.hpp * remove step hack * separate split-k instance files * add tuning parameters * change desired grid size to parameters * remove slice length * add desiredgridsize parameter to ckProfiler * add missing file device_gemm_xdl_splitk_instance.hpp * change desired grid size to kbatch * format * format * clean up * add selection of device_instances * clean code * fix build issue Co-authored-by: ltqin Co-authored-by: Chao Liu Co-authored-by: Jing Zhang [ROCm/composable_kernel commit: 4be7f0198e55f386d51cdb127dc0fa69427d6fe0] --- .../gridwise_gemm_xdlops_v2r4.hpp | 46 +- ...emm_xdl_f16_f16_f16_km_kn_mn_instance.cpp} | 21 +- ...emm_xdl_f16_f16_f16_km_nk_mn_instance.cpp} | 21 +- ...emm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp} | 21 +- ...emm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp} | 21 +- ...emm_xdl_f32_f32_f32_km_kn_mn_instance.cpp} | 21 +- ...emm_xdl_f32_f32_f32_km_nk_mn_instance.cpp} | 21 +- ...emm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp} | 21 +- ...emm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp} | 21 +- ...l_splitk_f32_f32_f32_km_kn_mn_instance.cpp | 51 ++ ...l_splitk_f32_f32_f32_km_nk_mn_instance.cpp | 51 ++ ...l_splitk_f32_f32_f32_mk_kn_mn_instance.cpp | 51 ++ ...l_splitk_f32_f32_f32_mk_nk_mn_instance.cpp | 56 ++ device_operation/include/device_gemm.hpp | 26 +- .../include/device_gemm_instance.hpp | 27 - device_operation/include/device_gemm_xdl.hpp | 3 +-
.../include/device_gemm_xdl_splitk.hpp | 606 ++++++++++++++++++ profiler/CMakeLists.txt | 23 +- profiler/include/profile_gemm_impl.hpp | 198 +++--- profiler/profile_gemm.cpp | 22 +- test/CMakeLists.txt | 9 +- test/split_k/main.cpp | 218 +++++++ 22 files changed, 1279 insertions(+), 276 deletions(-) rename device_operation/{device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp => device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp} (90%) rename device_operation/{device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp => device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp} (90%) rename device_operation/{device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp => device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp} (90%) rename device_operation/{device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp => device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp} (93%) rename device_operation/{device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp => device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp} (90%) rename device_operation/{device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp => device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp} (90%) rename device_operation/{device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp => device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp} (90%) rename device_operation/{device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp => device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp} (93%) create mode 100644 device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp create mode 100644 device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp create mode 100644 device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp create mode 100644 device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp delete mode 100644 device_operation/include/device_gemm_instance.hpp create mode 100644 device_operation/include/device_gemm_xdl_splitk.hpp create mode 100644 test/split_k/main.cpp diff --git 
a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp index 39a910a6ff..7983b0e834 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp @@ -62,7 +62,10 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -74,7 +77,10 @@ __global__ void const void CONSTANT* p_a_b_k0_m_k1_grid_desc, const void CONSTANT* p_b_b_k0_n_k1_grid_desc, const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const void CONSTANT* p_c_block_cluster_adaptor) + const void CONSTANT* p_a_element_op, + const void CONSTANT* p_b_element_op, + const void CONSTANT* p_c_element_op, + const void CONSTANT* p_block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -86,8 +92,14 @@ __global__ void const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = *reinterpret_cast( cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); - const auto c_block_cluster_adaptor = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); + const auto block_2_ctile_map = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_block_2_ctile_map)); + const auto a_element_op = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_element_op)); + const auto b_element_op = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_element_op)); + const auto c_element_op = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_element_op)); __shared__ FloatAB p_shared_block[shared_block_size]; @@ -98,7 +110,10 @@ __global__ void a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); } #endif @@ -110,6 +125,9 @@ template + index_t 
CThreadTransferDstScalarPerVector> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 { static constexpr auto I0 = Number<0>{}; @@ -358,6 +373,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, const CBlockClusterAdaptor& c_block_cluster_adaptor) { const auto a_grid_buf = make_dynamic_buffer( @@ -456,7 +474,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 ck::tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum_t::Set, Sequence<1, K0PerBlock, MPerBlock, K1>, - ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, FloatAB, @@ -487,7 +504,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 ck::tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum_t::Set, Sequence<1, K0PerBlock, NPerBlock, K1>, - BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, FloatAB, @@ -583,8 +599,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); - k_block_data_begin += K0PerBlock; - } while(k_block_data_begin < (K0 - K0PerBlock)); + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); } // tail diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp b/device_operation/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp similarity index 90% rename from device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp rename to device_operation/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp index 78f5352f7e..f8ff5406d5 100644 
--- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -1,8 +1,8 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_instance.hpp" #include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" namespace ck { namespace tensor_operation { @@ -21,7 +21,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = +using device_gemm_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_km_kn_mn = // clang-format on >; -template <> -void add_device_gemm_instance( - std::vector>& device_op_instances) +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) { - using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_kn_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp b/device_operation/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp similarity index 90% rename from 
device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp rename to device_operation/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp index 786c4ab1e1..8fa9c0b66a 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -1,8 +1,8 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_instance.hpp" #include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" namespace ck { namespace tensor_operation { @@ -21,7 +21,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = +using device_gemm_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_km_nk_mn = // clang-format on >; -template <> -void add_device_gemm_instance( - std::vector>& device_op_instances) +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) { - using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_km_nk_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_km_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git 
a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp similarity index 90% rename from device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp rename to device_operation/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index 44459ca4cb..692319a4e9 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,8 +1,8 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_instance.hpp" #include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" namespace ck { namespace tensor_operation { @@ -21,7 +21,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = +using device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn = // clang-format on >; -template <> -void add_device_gemm_instance( - std::vector>& device_op_instances) +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) { - using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - 
device_op_instances.push_back(std::make_unique(gemm)); - }); + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp similarity index 93% rename from device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp rename to device_operation/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 7286dfe598..cbf2020df1 100644 --- a/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,8 +1,8 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_instance.hpp" #include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" namespace ck { namespace tensor_operation { @@ -21,7 +21,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = +using device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| @@ -44,21 +44,10 @@ using device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn = // clang-format on >; -template <> -void add_device_gemm_instance( - std::vector>& device_op_instances) +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) { - using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn; - - const 
auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); + add_device_operation_instances(instances, device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp b/device_operation/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp similarity index 90% rename from device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp rename to device_operation/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp index 344f182fa3..d893209a61 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -1,8 +1,8 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_instance.hpp" #include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" namespace ck { namespace tensor_operation { @@ -21,7 +21,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = +using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f32_f32_f32_km_kn_mn = // clang-format on >; -template <> -void add_device_gemm_instance( - std::vector>& device_op_instances) +void 
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) { - using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_kn_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp b/device_operation/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp similarity index 90% rename from device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp rename to device_operation/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp index fb17e0aaea..036c1aeb3c 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp @@ -1,8 +1,8 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_instance.hpp" #include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" namespace ck { namespace tensor_operation { @@ -21,7 +21,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = +using device_gemm_xdl_f32_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| @@ -39,21 
+39,10 @@ using device_gemm_xdl_instance_f32_f32_f32_km_nk_mn = // clang-format on >; -template <> -void add_device_gemm_instance( - std::vector>& device_op_instances) +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) { - using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_km_nk_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_km_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp b/device_operation/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp similarity index 90% rename from device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp rename to device_operation/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp index 7567a8c2ec..7379493fbe 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp +++ b/device_operation/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -1,8 +1,8 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_instance.hpp" #include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" namespace ck { namespace tensor_operation { @@ -21,7 +21,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = +using device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| @@ -39,21 +39,10 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = // clang-format on >; -template <> -void add_device_gemm_instance( - std::vector>& device_op_instances) +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) { - using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp b/device_operation/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp similarity index 93% rename from device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp rename to device_operation/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp index 6c80f0d9f4..b474262823 100644 --- a/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp +++ b/device_operation/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -1,8 +1,8 @@ #include #include "config.hpp" #include "device_gemm_xdl.hpp" -#include "device_gemm_instance.hpp" #include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" namespace ck { namespace tensor_operation { @@ -21,7 +21,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = +using device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off //##########| AData| BData| CData| AccData| 
ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| @@ -44,21 +44,10 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn = // clang-format on >; -template <> -void add_device_gemm_instance( - std::vector>& device_op_instances) +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) { - using DeviceGemms = device_gemm_instance::device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn; - - const auto device_gemms = DeviceGemms{}; - - ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) { - using Gemm = remove_cvref_t(device_gemms))>; - - auto gemm = Gemm{}; - - device_op_instances.push_back(std::make_unique(gemm)); - }); + add_device_operation_instances(instances, device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..5d548bfc26 --- /dev/null +++ b/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using 
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, 
Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..b0218fd027 --- /dev/null +++ b/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk.hpp" +#include 
"element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, 
PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace 
device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp b/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..524fd364c2 --- /dev/null +++ b/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,51 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 96, 128, 4, 8, 16, 16, 3, 4, S<1, 4, 32, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 4, true, 7, 1>, + 
DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 4, true, 7, 1> + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..f2526e131d --- /dev/null +++ b/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,56 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_splitk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, 
Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, 
S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1>, + DeviceGemmXdlSplitK< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/include/device_gemm.hpp b/device_operation/include/device_gemm.hpp index cf45829ca4..5b386bd908 100644 --- a/device_operation/include/device_gemm.hpp +++ b/device_operation/include/device_gemm.hpp @@ -13,19 +13,19 @@ template struct DeviceGemm : public BaseOperator { - virtual std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) = 0; + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/device_operation/include/device_gemm_instance.hpp b/device_operation/include/device_gemm_instance.hpp deleted file mode 100644 index 1edaf090dd..0000000000 --- a/device_operation/include/device_gemm_instance.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef DEVICE_GEMM_INSTANTCE_HPP -#define 
DEVICE_GEMM_INSTANTCE_HPP - -#include "device_gemm.hpp" -#include "element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { - -template -void add_device_gemm_instance( - std::vector>&); - -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/device_operation/include/device_gemm_xdl.hpp b/device_operation/include/device_gemm_xdl.hpp index 9e5ee80381..927084815b 100644 --- a/device_operation/include/device_gemm_xdl.hpp +++ b/device_operation/include/device_gemm_xdl.hpp @@ -408,7 +408,8 @@ struct DeviceGemmXdl index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override + CElementwiseOperation c_element_op, + ck::index_t) override { return std::make_unique(static_cast(p_a), static_cast(p_b), diff --git a/device_operation/include/device_gemm_xdl_splitk.hpp b/device_operation/include/device_gemm_xdl_splitk.hpp new file mode 100644 index 0000000000..ed29d40ab0 --- /dev/null +++ b/device_operation/include/device_gemm_xdl_splitk.hpp @@ -0,0 +1,606 @@ +#ifndef DEVICE_GEMM_SPLITK_XDL_HPP +#define DEVICE_GEMM_SPLITK_XDL_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r4.hpp" + +#ifndef CK_RUN_KERNEL_AND_TIME +#define CK_RUN_KERNEL_AND_TIME 1 +#endif + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlSplitK + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + + static auto + 
MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto a_grid_desc_kbatch_k0_m_k1 = transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + return a_grid_desc_kbatch_k0_m_k1; + } + + static auto + MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) + { + assert(KPad % (K1 * KBatch) == 0); + + const index_t K0 = KPad / (K1 * KBatch); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_kbatch_k0_n_k1 = transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), + 
make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + return b_grid_desc_kbatch_k0_n_k1; + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + } + + static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + 
BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // GridwiseGemm + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + 
index_t k_batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + k_batch_{k_batch} + { + int KPad = DeviceGemmXdlSplitK::GetKPad(K, k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmXdlSplitK::MakeAGridDescriptor_KBatch_K0_M_K1( + M, K, StrideA, k_batch_, KPad); + b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmXdlSplitK::MakeBGridDescriptor_KBatch_K0_N_K1( + K, N, StrideB, k_batch_, KPad); + c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + M01_, + N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdlSplitK::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << 
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + float Run(const Argument& arg, int nrepeat = 1) + { + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(nrepeat > 0) + { + ShowInfo(arg); + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + if(kbatch > 1 || nrepeat <= 0) + { + hipGetErrorString( + hipMemset(arg.p_c_grid_, + 0, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() * + sizeof(CDataType))); + + launch_kernel(kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + 
arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + }; + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r4< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return 
GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdlSplitK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 6ef9cd6014..7de9e1a378 100644 --- 
a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -14,14 +14,18 @@ include_directories(BEFORE # device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f32_f32_f32_km_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_mk_nk_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_kn_mn.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_instance_f16_f16_f16_km_nk_mn.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; + 
${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; ) add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) @@ -83,7 +87,8 @@ set(PROFILER_SOURCE profile_conv_fwd.cpp profile_conv_fwd_bias_relu.cpp profile_conv_fwd_bias_relu_add.cpp - profile_conv_fwd_bias_relu_atomic_add.cpp) + profile_conv_fwd_bias_relu_atomic_add.cpp + ) add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 3e99928fa4..596770190b 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -1,78 +1,29 @@ #pragma once -#include "device_gemm_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { -using DeviceGemmNoOpPtr = DeviceGemmPtr; +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; -template <> -void add_device_gemm_instance(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); -template <> -void add_device_gemm_instance(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector&); -template <> -void add_device_gemm_instance(std::vector&); - -template <> -void add_device_gemm_instance(std::vector&); - -template <> -void add_device_gemm_instance(std::vector&); - -template <> -void add_device_gemm_instance(std::vector&); - -template <> -void add_device_gemm_instance(std::vector&); - -template <> 
-void add_device_gemm_instance(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); } // namespace device_gemm_instance } // namespace device @@ -97,7 +48,8 @@ void profile_gemm_impl(int do_verification, int K, int StrideA, int StrideB, - int StrideC) + int StrideC, + int KBatch = 1) { auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -122,17 +74,20 @@ void profile_gemm_impl(int do_verification, std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::size_t num_thread = std::thread::hardware_concurrency(); switch(init_method) { case 0: break; case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); if(do_verification) { @@ -155,9 +110,103 @@ void profile_gemm_impl(int do_verification, // add device GEMM instances std::vector gemm_ptrs; - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_instance( - gemm_ptrs); + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + 
ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + if(KBatch > 1) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + } + } + else if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + 
ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } if(gemm_ptrs.size() <= 0) { @@ -184,7 +233,8 @@ void profile_gemm_impl(int do_verification, StrideC, ck::tensor_operation::element_wise::PassThrough{}, ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}); + ck::tensor_operation::element_wise::PassThrough{}, + KBatch); auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); diff --git a/profiler/profile_gemm.cpp b/profiler/profile_gemm.cpp index c34c3376f4..37d5b4f2ee 100644 --- a/profiler/profile_gemm.cpp +++ b/profiler/profile_gemm.cpp @@ -35,19 +35,20 @@ enum GemmDataType int profile_gemm(int argc, char* argv[]) { - if(argc != 14) + if(!(argc == 14 || argc == 15)) { printf("arg1: tensor operation (gemm: GEMM)\n"); printf("arg2: data type (0: fp32; 1: fp16)\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); - printf(" 2: A[k, n] * B[k, n] = C[m, n];\n"); - printf(" 3: A[k, n] * B[n, k] = C[m, n])\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n"); printf("arg7: run kernel # of times (>1)\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); exit(1); } @@ -65,6 +66,9 @@ int profile_gemm(int argc, char* argv[]) const int StrideA = std::stoi(argv[11]); const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); + int KBatch = 1; + if(argc == 15) + KBatch = 
std::stoi(argv[14]); if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -159,7 +163,8 @@ int profile_gemm(int argc, char* argv[]) K, (StrideA < 0) ? K : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + KBatch); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) { @@ -178,7 +183,8 @@ int profile_gemm(int argc, char* argv[]) K, (StrideA < 0) ? K : StrideA, (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + KBatch); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) { @@ -197,7 +203,8 @@ int profile_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? N : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? N : StrideC, + KBatch); } else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -216,7 +223,8 @@ int profile_gemm(int argc, char* argv[]) K, (StrideA < 0) ? M : StrideA, (StrideB < 0) ? K : StrideB, - (StrideC < 0) ? N : StrideC); + (StrideC < 0) ? 
N : StrideC, + KBatch); } else { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c74349d76c..1b3e1e57e5 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,8 +11,13 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/external/rocm/include ) +# test_magic_number_division set(MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp) - add_executable(test_magic_number_division ${MAGIC_NUMBER_DIVISISON_SOURCE}) - target_link_libraries(test_magic_number_division PRIVATE host_tensor) + +# test_split_k +set(SPLIT_K_SOURCE split_k/main.cpp) +add_executable(test_split_k ${SPLIT_K_SOURCE}) +target_link_libraries(test_split_k PRIVATE host_tensor) +target_link_libraries(test_split_k PRIVATE device_gemm_instance) diff --git a/test/split_k/main.cpp b/test/split_k/main.cpp new file mode 100644 index 0000000000..3097f4e925 --- /dev/null +++ b/test/split_k/main.cpp @@ -0,0 +1,218 @@ +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "host_gemm.hpp" +#include "tensor_layout.hpp" +#include "device_gemm_xdl_splitk.hpp" + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +template +static bool check_out(const Tensor& ref, const Tensor& result) +{ + 
float max_diff = 1e-6; + + for(int i = 0; i < ref.mData.size(); ++i) + { + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + return false; + } + } + + return true; +} + +int main(int argc, char* argv[]) +{ + if(argc != 9) + { + printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n"); + return 1; + } + + const int layout = static_cast(std::stoi(argv[1])); + + const int M = std::stoi(argv[2]); + const int N = std::stoi(argv[3]); + const int K = std::stoi(argv[4]); + + const int StrideA = std::stoi(argv[5]); + const int StrideB = std::stoi(argv[6]); + const int StrideC = std::stoi(argv[7]); + const int KBatch = std::stoi(argv[8]); + + bool a_row_major, b_row_major, c_row_major; + + switch(layout) + { + case GemmMatrixLayout::MK_KN_MN: + a_row_major = true; + b_row_major = true; + c_row_major = true; + break; + case GemmMatrixLayout::MK_NK_MN: + a_row_major = true; + b_row_major = false; + c_row_major = true; + break; + case GemmMatrixLayout::KM_KN_MN: + a_row_major = false; + b_row_major = true; + c_row_major = true; + break; + case GemmMatrixLayout::KM_NK_MN: + a_row_major = false; + b_row_major = false; + c_row_major = true; + break; + default: printf("not supported layout"); return 1; + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, bool row_major) { + if(row_major) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, a_row_major)); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, b_row_major)); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, 
N, StrideC, c_row_major)); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, c_row_major)); + + // init data + std::size_t num_thread = std::thread::hardware_concurrency(); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + host_gemm_mk_kn_mn(a_m_k, + b_k_n, + c_m_n_host_result, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + + // add device GEMM instances + std::vector gemm_ptrs; + + if(layout == GemmMatrixLayout::MK_KN_MN) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + } + else + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + } + + bool success = false; + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + 
static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), 0); + + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + if(!check_out(c_m_n_host_result, c_m_n_device_result)) + { + success = false; + break; + } + success = true; + } + } + + if(success) + { + std::cout << "test split k : Pass" << std::endl; + } + else + { + std::cout << "test split k: Fail " << std::endl; + } + return 0; +}