From 974524f2c1f306fbb8aeb0279fc22caf67f0a971 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:02:44 +0100 Subject: [PATCH] Polished Grouped GEMM APIs and new BF16 instances (#1600) * Few small fixes. * New GroupedGemm instances (BF16) * Unify and refactor GroupedGEMM device API. * Adapt changes to new API. * Adapt grouped gemm profiler. * Accept multiple kbatches for grouped gemm profiler. - delete obsolete two stage as it is now covered by grouped gemm * Update unit test for grouped gemm. * Fix thresholds for BF16 and F8. Unblock tests. * Fix few instances. * Multiple small fixes. * Adapt to new API, check dynamic casting. * Uncomment few data types in grouped gemm profiler. * Fix call to SetDeviceArgs. * Fix profile grouped gemm multiply tile loop. * Fix grouped gemm tile loop kernel args in client examples. * Review comments. [ROCm/composable_kernel commit: 061ac0649c75deb315a418466d00dea2c49e65f3] --- ...emm_multiply_bias_fastgelu_xdl_bf16_i8.cpp | 2 +- .../grouped_gemm_multiply_xdl_bf16_i8.cpp | 2 +- ...rouped_gemm_multiple_d_splitk_xdl_fp16.cpp | 4 +- .../grouped_gemm_multiple_d_xdl_fp16.cpp | 2 +- .../grouped_gemm_xdl_fixed_nk_bias_fp16.cpp | 4 +- .../grouped_gemm_xdl_fixed_nk_fp16.cpp | 4 +- .../grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp | 4 +- .../run_grouped_gemm_example.inc | 18 +- .../gpu/device/device_grouped_gemm.hpp | 132 ++++++- .../device/device_grouped_gemm_fixed_nk.hpp | 50 +-- .../device_grouped_gemm_multiple_d_splitk.hpp | 136 ------- .../gpu/device/device_grouped_gemm_splitk.hpp | 20 +- .../device/device_grouped_gemm_tile_loop.hpp | 92 +---- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 99 +++-- ...gemm_multiple_d_xdl_cshuffle_tile_loop.hpp | 24 +- .../device/impl/device_grouped_gemm_xdl.hpp | 21 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 72 +++- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 35 +- include/ck/utility/loop_scheduler.hpp | 1 - .../gpu/grouped_gemm.hpp | 185 ++++++++- ...evice_grouped_gemm_xdl_splitk_instance.hpp | 138 +++++++ .../gpu/grouped_gemm/CMakeLists.txt | 22 +- ..._bf16_bf16_bf16_km_kn_mn_irregular_pv1.cpp | 32 ++ ...bf16_bf16_km_kn_mn_irregular_pv1_inter.cpp | 36 ++ ..._bf16_bf16_bf16_km_kn_mn_irregular_pv2.cpp | 33 ++ ..._bf16_bf16_bf16_mk_kn_mn_irregular_pv1.cpp | 32 ++ ...bf16_bf16_mk_kn_mn_irregular_pv1_inter.cpp | 36 ++ ..._bf16_bf16_bf16_mk_kn_mn_irregular_pv2.cpp | 38 ++ ..._bf16_bf16_bf16_mk_nk_mn_irregular_pv1.cpp | 32 ++ ...bf16_bf16_mk_nk_mn_irregular_pv1_inter.cpp | 36 ++ ..._bf16_bf16_bf16_mk_nk_mn_irregular_pv2.cpp | 33 ++ ...l_splitk_f16_f16_f16_mk_kn_mn_instance.cpp | 47 +-- ...16_f16_f16_mk_kn_mn_irregular_instance.cpp | 123 ------ ...itk_f16_f16_f16_mk_kn_mn_irregular_pv1.cpp | 32 ++ ...6_f16_f16_mk_kn_mn_irregular_pv1_inter.cpp | 36 ++ ...itk_f16_f16_f16_mk_kn_mn_irregular_pv2.cpp | 33 ++ ...l_splitk_f16_f16_f16_mk_nk_mn_instance.cpp | 51 +-- ...16_f16_f16_mk_nk_mn_irregular_instance.cpp | 55 +-- ...ultiply_bf16_i8_bf16_mk_kn_mn_instance.cpp | 234 ----------- .../profiler/profile_grouped_gemm_impl.hpp | 119 +++--- ...e_grouped_gemm_multiply_tile_loop_impl.hpp | 3 +- .../profile_grouped_gemm_tile_loop_impl.hpp | 2 +- .../profile_grouped_gemm_two_stage_impl.hpp | 367 ------------------ profiler/src/CMakeLists.txt | 1 - profiler/src/profile_grouped_gemm.cpp | 89 ++++- .../src/profile_grouped_gemm_fixed_nk.cpp | 8 +- .../src/profile_grouped_gemm_two_stage.cpp | 228 ----------- test/grouped_gemm/CMakeLists.txt | 6 - .../test_grouped_gemm_splitk_xdl.cpp | 42 
+- .../test_grouped_gemm_ut_cases.inc | 133 +------ test/grouped_gemm/test_grouped_gemm_util.hpp | 139 +++---- 51 files changed, 1400 insertions(+), 1723 deletions(-) delete mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp delete mode 100644 profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp delete mode 100644 profiler/src/profile_grouped_gemm_two_stage.cpp diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp index 4b284c74d4..47d3e0abf9 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp @@ -121,7 +121,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co constexpr ck::index_t NumDTensor = 2; using GroupedGemmKernelArgument = - ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + ck::tensor_operation::device::GroupedGemmKernelArgument; std::vector grouped_gemm_kernel_args_; grouped_gemm_kernel_args_.reserve(group_count); diff --git 
a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp index 6cc83e06f6..8c705d3bcc 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp @@ -120,7 +120,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co constexpr ck::index_t NumDTensor = 1; using GroupedGemmKernelArgument = - ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + ck::tensor_operation::device::GroupedGemmKernelArgument; std::vector grouped_gemm_kernel_args_; grouped_gemm_kernel_args_.reserve(group_count); diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp index ecff7b4713..8bbf8e629e 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp @@ -246,7 +246,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co // do GEMM auto argument = gemm.MakeArgument( p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); - gemm.SetKBatchSize(argument, config.k_batch); + gemm.SetKBatchSize(&argument, config.k_batch); if(!gemm.IsSupportedArgument(argument)) { throw std::runtime_error( @@ -257,7 +257,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); invoker.Run(argument, StreamConfig{nullptr, false, 1}); diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 965a0e7e37..e7b2ee4173 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -91,7 +91,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co { auto group_count = problem_size.group_count; - using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; using GemmDesc = ck::tensor_operation::device::GemmDesc; // GEMM shape diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp index a193fc39ba..3b3ef508ce 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
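Note: the example hunks above switch the grouped-GEMM setters to the pointer-based form (gemm.SetKBatchSize(&argument, ...), gemm.SetDeviceKernelArgs(&argument, ...)). For reference, a minimal condensed sketch of the resulting call sequence, assuming `gemm`, `invoker`, the pointer vectors, `gemm_descs`, the element ops and `config` are set up as in the surrounding examples (includes, data initialization and error handling elided; this is a sketch, not the full example code):

    // Build the argument, then configure it through the pointer-based API.
    auto argument = gemm.MakeArgument(
        p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op);
    gemm.SetKBatchSize(&argument, config.k_batch);

    if(!gemm.IsSupportedArgument(argument))
    {
        throw std::runtime_error("The device op does not support this GEMM problem");
    }

    // Workspace and device-side kernel-argument buffers are sized by the device op.
    DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
    gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());

    DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument));
    gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer());

    invoker.Run(argument, StreamConfig{nullptr, false, 1});
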
#include #include @@ -254,7 +254,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm.GetDeviceKernelArgSize(&argument), hipMemcpyHostToDevice)); - gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_kernel_args_dev.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 1a2bcfb33e..c1043f419d 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -239,7 +239,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 0a63a29843..c81874b066 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -240,7 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 320870e0de..7cb0588b82 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -168,9 +168,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto argument = gemm.MakeArgument( p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); - DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); + std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument); - gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + DeviceMem gemm_workspace, gemm_kargs; + + // The following is necessary since TwoStage kernel is using additional memory both + // for Workspace and kernel arguments. 
+ if(kargs_size > 0) + { + gemm_kargs.Realloc(kargs_size); + gemm.SetDeviceKernelArgs(&argument, gemm_kargs.GetDeviceBuffer()); + } + if(workspace_size > 0 && workspace_size != kargs_size) + { + gemm_workspace.Realloc(workspace_size); + gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer()); + } if(!gemm.IsSupportedArgument(argument)) { diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp index 1e03405536..267a970ee5 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -1,17 +1,87 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include #include +#include +#include #include #include "device_base.hpp" +#include "ck/utility/ignore.hpp" namespace ck { namespace tensor_operation { namespace device { +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. +/// +template +struct GroupedGemmKernelArgument +{ + __host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + struct GemmDesc { ck::index_t M_, N_, K_; @@ -48,6 +118,66 @@ struct DeviceGroupedGemm : public BaseOperator CElementwiseOperation c_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; + + //--------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer and may copy data to device. + /// + /// TODO: Add which kernels are using this (TileLoop * FixedNK ??) + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which will contain kernel + /// arguments. + /// @param[in] p_host_kernel_args The pointer to the host memory which contains kernel + /// arguments that should be copied to device memory. 
+ /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + ignore = p_arg; + ignore = p_dev_kernel_args; + ignore = p_host_kernel_args; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer and may copy data to device. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. + /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const + { + ignore = p_arg; + ignore = p_dev_kernel_args; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. + /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const + { + ignore = p_arg; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp index fcb2ba6a4d..780a0c30c5 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp @@ -1,35 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
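Note: the DeviceGroupedGemm base interface above gains default (throwing) implementations of SetDeviceKernelArgs, with and without a host-side source buffer, and of GetDeviceKernelArgSize, so only device ops that need these hooks override them. A minimal sketch of how a caller might drive the three-argument overload, assuming `gemm`, `argument`, `group_count` and the per-group pointers/sizes are prepared as in the client examples above (NumDTensor, the container names and the loop bounds are illustrative only):

    constexpr ck::index_t NumDTensor = 0;
    using KernelArgs = ck::tensor_operation::device::GroupedGemmKernelArgument<NumDTensor>;

    // Pack one kernel-argument struct per group on the host.
    std::vector<KernelArgs> host_kernel_args;
    host_kernel_args.reserve(group_count);
    for(int i = 0; i < group_count; ++i)
    {
        host_kernel_args.push_back({p_a[i], p_b[i], {}, p_c[i],
                                    Ms[i], Ns[i], Ks[i],
                                    StrideAs[i], StrideBs[i], {}, StrideCs[i]});
    }

    // Ask the device op how much device memory its kernel arguments need,
    // then let it copy the host-side arguments into that buffer.
    DeviceMem kargs_dev(gemm.GetDeviceKernelArgSize(&argument));
    gemm.SetDeviceKernelArgs(&argument, kargs_dev.GetDeviceBuffer(), host_kernel_args.data());

Device ops that do not implement a given hook keep the base-class default, which throws a std::runtime_error naming the op.
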
#pragma once -#include -#include - -#include "device_grouped_gemm.hpp" +#include "device_grouped_gemm_splitk.hpp" namespace ck { namespace tensor_operation { namespace device { -template -struct GroupedGemmKernelArgument -{ - const void* p_a_grid; - const void* p_b_grid; - std::array p_ds_grid; - void* p_e_grid; - - index_t M; - index_t N; - index_t K; - - index_t StrideA; - index_t StrideB; - std::array StrideDs; - index_t StrideE; -}; - template -struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm +struct DeviceGroupedGemmFixedNK : DeviceGroupedGemmSplitK { - virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; - virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; - virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp deleted file mode 100644 index d91eac0730..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -#include "device_grouped_gemm.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -/// -/// @brief Structure representing single GEMM problem arguments. -/// -/// The pointer to the vector of those structures is passed to the GroupedGEMM entry -/// point kernel. -/// -/// @tparam NumDTensor The number of D input tensors. -/// -template -struct GroupedGemmMultipleDKernelArguments -{ - __host__ __device__ - GroupedGemmMultipleDKernelArguments(const void* p_a_grid_, - const void* p_b_grid_, - std::array p_ds_grid_, - void* p_e_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideE_) - : p_a_grid{p_a_grid_}, - p_b_grid{p_b_grid_}, - p_ds_grid{p_ds_grid_}, - p_e_grid{p_e_grid_}, - M{M_}, - N{N_}, - K{K_}, - StrideA{StrideA_}, - StrideB{StrideB_}, - StrideDs{StrideDs_}, - StrideE{StrideE_} - { - } - - const void* p_a_grid; - const void* p_b_grid; - std::array p_ds_grid; - void* p_e_grid; - index_t M; - index_t N; - index_t K; - index_t StrideA; - index_t StrideB; - std::array StrideDs; - index_t StrideE; - - void Print() const - { - std::stringstream str; - for(auto sd : StrideDs) - str << sd << ","; - - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SE:" << StrideE << ", " - << "SDs: {" << str.str() << "}" - << "}" << std::endl; - } -}; - -template -struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm -{ - //---------------------------------------------------------------------------------------------- - /// @brief Sets the k batch size. - /// - /// @param p_arg Pointer to the Argument we're going to change. - /// @param[in] kbatch The kbatch value. - /// - virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; - - //---------------------------------------------------------------------------------------------- - /// @brief Sets the device kernel arguments pointer. - /// - /// @param p_arg The pointer to the Argument we're going to update. 
- /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel - /// arguments. - /// - virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; - - //---------------------------------------------------------------------------------------------- - /// @brief Gets the device kernel argument size. - /// - /// @param[in] p_arg The pointer to the Device op Argument. - /// - /// @return The device kernel argument size. - /// - virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp index 06d180d30f..3ea6501902 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp @@ -1,6 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include -#include #include "device_grouped_gemm.hpp" @@ -31,7 +31,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm { + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// + virtual void SetKBatch(BaseArgument* p_arg, index_t kbatch) const + { + this->SetKBatchSize(p_arg, kbatch); + }; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp index c1030f31cc..712fbfd9e9 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp @@ -3,83 +3,20 @@ #pragma once -#include -#include -#include -#include - #include "device_grouped_gemm.hpp" namespace ck { namespace tensor_operation { namespace device { +/// @brief Grouped GEMM kernel using output Tile Looping algorithm /// -/// @brief Structure representing single GEMM problem arguments. +/// @par This kernel does not require any knowledge about input data sizes (GEMM M/N/K) +/// It requires only the number of groups to launch. Other information like +/// data pointers and GEMM sizes, packed into gemm kernel args may be all dynamic +/// (known only at kernel run-time). /// -/// The pointer to the vector of those structures is passed to the GroupedGEMM entry -/// point kernel. -/// -/// @tparam NumDTensor The number of D input tensors. 
-/// -template -struct GroupedGemmTileLoopKernelArguments -{ - __host__ __device__ - GroupedGemmTileLoopKernelArguments(const void* p_a_grid_, - const void* p_b_grid_, - std::array p_ds_grid_, - void* p_e_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t StrideA_, - index_t StrideB_, - std::array StrideDs_, - index_t StrideE_) - : p_a_grid{p_a_grid_}, - p_b_grid{p_b_grid_}, - p_ds_grid{p_ds_grid_}, - p_e_grid{p_e_grid_}, - M{M_}, - N{N_}, - K{K_}, - StrideA{StrideA_}, - StrideB{StrideB_}, - StrideDs{StrideDs_}, - StrideE{StrideE_} - { - } - - const void* p_a_grid; - const void* p_b_grid; - std::array p_ds_grid; - void* p_e_grid; - index_t M; - index_t N; - index_t K; - index_t StrideA; - index_t StrideB; - std::array StrideDs; - index_t StrideE; - - void Print() const - { - std::stringstream str; - for(auto sd : StrideDs) - str << sd << ","; - - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SE:" << StrideE << ", " - << "SDs: {" << str.str() << "}" - << "}" << std::endl; - } -}; +/// @note This kernel does not support SplitK. template { - //---------------------------------------------------------------------------------------------- - /// @brief Sets the device kernel arguments pointer. - /// - /// @param p_arg The pointer to the Argument we're going to update. - /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel - /// arguments. - /// - virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; - - //---------------------------------------------------------------------------------------------- - /// @brief Gets the device kernel argument size. - /// - /// @param[in] p_arg The pointer to the Device op Argument. - /// - /// @return The device kernel argument size. - /// - virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 68c6dcc0f5..0535c80323 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -18,7 +18,6 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -78,17 +77,17 @@ template = false> struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage - : public DeviceGroupedGemmMultipleDSplitK + : public DeviceGroupedGemmSplitK { using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; @@ -530,7 +529,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage index_t skipped_group_count_; index_t grid_size_; // Pointer to device memory with GEMM kernel arguments. 
- const void* p_dev_gemm_args_; + void* p_dev_gemm_kargs_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -566,7 +565,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage /// @return The average kernel execution time (if time measurement is enabled.) /// float Run(const Argument& arg, - const void* dev_gemm_args, + void* dev_gemm_args, void* dev_gemm_workspace, const StreamConfig& stream_config = StreamConfig{}) { @@ -621,7 +620,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage /// float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(arg.p_dev_gemm_args_ == nullptr) + if(arg.p_dev_gemm_kargs_ == nullptr) { std::ostringstream err; err << "The gemm arguments device buffer is not allocated!" @@ -637,7 +636,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage throw std::runtime_error(err.str()); } - return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config); + return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config); } float Run(const BaseArgument* p_arg, @@ -723,7 +722,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage template float DispatchKernel(const Argument& arg, - const void* dev_gemm_args, + void* dev_gemm_kargs, void* dev_gemm_workspace, const StreamConfig& stream_config) const { @@ -746,7 +745,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage return LaunchKernel(gemm_kernel, elementwise_kernel, arg, - dev_gemm_args, + dev_gemm_kargs, dev_gemm_workspace, stream_config); } @@ -755,12 +754,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage float LaunchKernel(const KernelFunction& gemm_kernel, const KernelFunction2& elementwise_kernel, const Argument& arg, - const void* dev_gemm_args, + void* dev_gemm_kargs, [[maybe_unused]] void* dev_gemm_workspace, const StreamConfig& stream_config) const { float time{0.f}; + hip_check_error( + hipMemcpyWithStream(dev_gemm_kargs, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + auto preprocess = [&]() { hip_check_error(hipMemsetAsync( dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); @@ -774,7 +780,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage dim3(arg.grid_size_), dim3(BlockSize), 0, - cast_pointer_to_constant_address_space(dev_gemm_args), + cast_pointer_to_constant_address_space(dev_gemm_kargs), arg.gemm_kernel_args_.size(), arg.a_element_op_, arg.b_element_op_, @@ -930,18 +936,30 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage return str.str(); } - void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const - { - arg.p_dev_gemm_args_ = p_dev_kernel_args; - hip_check_error(hipMemcpy(p_dev_kernel_args, - arg.gemm_kernel_args_.data(), - GetDeviceKernelArgSize(&arg), - hipMemcpyHostToDevice)); - } - void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override { - return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->gemm_kernel_args_.size() * 
sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); } size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override @@ -974,17 +992,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); } - static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + [[deprecated]] static void SetKBatchSize(Argument& arg, index_t kbatch) + { + arg.UpdateKBatch(kbatch); + } void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override { - return SetKBatchSize(*dynamic_cast(p_arg), kbatch); - } - - size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override - { - return dynamic_cast(p_arg)->gemm_kernel_args_.size() * - sizeof(GemmTransKernelArg); + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 2884e558cd..f673713f3e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -20,7 +20,6 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" namespace ck { @@ -522,7 +521,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop ComputeTypeA, ComputeTypeB>; - using KernelArguments = GroupedGemmTileLoopKernelArguments; + using KernelArguments = GroupedGemmKernelArgument; using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; @@ -936,12 +935,31 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop return str.str(); } + void SetDeviceKernelArgs(Argument& arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpy(p_dev_kernel_args, + p_host_kernel_args, + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const override + { + return SetDeviceKernelArgs( + *dynamic_cast(p_arg), p_dev_kernel_args, p_host_kernel_args); + } + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const { arg.p_dev_gemm_args_ = p_dev_kernel_args; } - void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override { return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 
658f323516..86cf1da156 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -717,7 +717,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg); + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index ac05a0703f..1fee02bad8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -445,6 +445,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + // TODO: replace with GroupedGemmKernelArgument struct GemmBiasTransKernelArg { // pointers @@ -900,40 +901,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK(p_arg), kernel_args); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->grouped_gemm_kernel_args_dev = kernel_args; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); } size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override { - auto arg = *dynamic_cast(p_arg); - - return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); } size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override { - auto arg = *dynamic_cast(p_arg); - - return arg.group_count_ * sizeof(GroupedGemmKernelArgument); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); } void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace, const StreamConfig& stream_config = StreamConfig{}) const override { - auto p_arg_ = dynamic_cast(p_arg); - p_arg_->p_workspace_ = p_workspace; + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->p_workspace_ = p_workspace; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); hip_check_error( - hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + 
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_)); } static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } @@ -941,7 +960,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK(p_arg), k_batch); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(k_batch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); + } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(kbatch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index cb0afbb08d..626ffbe979 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -546,7 +546,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK(p_arg)->gemm_kernel_args_.size() * - sizeof(GemmTransKernelArg); + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!"); } + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + // TODO: deperecation notice. 
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } // polymorphic void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override { - return SetKBatchSize(*dynamic_cast(p_arg), kbatch); + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!"); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); } }; diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp index 0c4d85bedb..a88109249d 100644 --- a/include/ck/utility/loop_scheduler.hpp +++ b/include/ck/utility/loop_scheduler.hpp @@ -5,7 +5,6 @@ #pragma once #include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index 87426fd52e..a999f9e3a0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -95,6 +95,45 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2_instances( + std::vector>>& instances); + void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances( std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2_instances( + std::vector>>& instances); + #endif #if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) @@ -262,7 +419,11 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instances( op_ptrs); + add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter_instances( + op_ptrs); + add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_instances( + op_ptrs); + 
add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter_instances( + op_ptrs); + add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_instances( + op_ptrs); + add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2_instances( + op_ptrs); } } #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp new file mode 100644 index 0000000000..7721e42c3c --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/utility/loop_scheduler.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto PipelineV1 = ck::PipelineVersion::v1; +static constexpr auto PipelineV2 = ck::PipelineVersion::v2; +static constexpr auto DefaultScheduler = ck::LoopScheduler::Default; +static constexpr auto InterwaveScheduler = ck::LoopScheduler::Interwave; +static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = device::GemmSpecialization::Default; + +template = false> +using device_grouped_gemm_xdl_splitk_2Bt_rrr_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline | Loop | + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Version | Scheduler | + //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 
32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 
S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler> + // clang-format on + >; + +template = false> +using device_grouped_gemm_xdl_splitk_2Bt_rcr_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline | Loop | + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Version | Scheduler | + //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, 
GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler> + // clang-format on + >; + +template = false> +using device_grouped_gemm_xdl_splitk_2Bt_crr_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline | Loop | + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Version | Scheduler | + //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, 
PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 2, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 
8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler>, + DeviceGroupedGemmXdlSplitKCShuffle< Col, Row, Empty_Tuple, Row, T, T, F32, T, Empty_Tuple, T, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Pipeline, Scheduler> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index de20321945..4a3e1a4ada 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -4,12 +4,30 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp - device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp + + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter.cpp + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1.cpp + device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2.cpp + + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter.cpp + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1.cpp + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2.cpp + + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter.cpp + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1.cpp + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2.cpp + + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter.cpp + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1.cpp + device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2.cpp + device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp + device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1.cpp new file mode 100644 index 0000000000..b8a03871cd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_splitk_2Bt_crr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter.cpp new file mode 100644 index 0000000000..10141165ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv1_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_splitk_2Bt_crr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2.cpp new file mode 100644 index 0000000000..b96f5983ce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_km_kn_mn_irregular_pv2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_splitk_2Bt_crr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1.cpp new file mode 100644 index 0000000000..8fad42316e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_splitk_2Bt_rrr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter.cpp new file mode 100644 index 0000000000..7845136ca6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv1_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_splitk_2Bt_rrr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2.cpp new file mode 100644 index 0000000000..a2d79edf6b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_kn_mn_irregular_pv2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_splitk_2Bt_rrr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1.cpp new file mode 100644 index 0000000000..033a2929f0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_splitk_2Bt_rcr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter.cpp new file mode 100644 index 0000000000..cf8c94bf46 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv1_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_splitk_2Bt_rcr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2.cpp new file mode 100644 index 0000000000..70c0d703ef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_bf16_bf16_bf16_mk_nk_mn_irregular_pv2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_splitk_2Bt_rcr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index 98e476f8bb..077a8a18ca 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -1,53 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using Empty_Tuple = ck::Tuple<>; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// a[m, k] * b[k, n] = e[m, n] -using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> - // clang-format on - >; - void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances) { - add_device_operation_instances(instances, - device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_splitk_2Bt_rrr_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp deleted file mode 100644 index ed0a8c7b70..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp +++ /dev/null @@ -1,123 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using Empty_Tuple = ck::Tuple<>; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< - // clang-format off - //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, - 
DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 
2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1>, - - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 
32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 
2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 
3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v1, LoopScheduler::Interwave>, - - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 
32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, 
PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, PipelineVersion::v2> - // clang-format on - >; - -void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1.cpp new file mode 100644 index 0000000000..8ad4736ac4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_splitk_2Bt_rrr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter.cpp new file mode 100644 index 0000000000..1d968c8210 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv1_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_splitk_2Bt_rrr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2.cpp new file mode 100644 index 0000000000..ee3d7d73b8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_pv2_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_splitk_2Bt_rrr_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index aa6365cd98..085e74f0ca 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -1,57 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using Empty_Tuple = ck::Tuple<>; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// a[m, k] * b[n, k] = e[m, n] -using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - 
DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances) { - add_device_operation_instances(instances, - device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, device_grouped_gemm_xdl_splitk_2Bt_rcr_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp index f4460b360b..320bb933b9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp @@ -1,63 +1,14 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_splitk_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using Empty_Tuple = ck::Tuple<>; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< - // clang-format off - //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, 
PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 
4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemmXdlSplitKCShuffle< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); + instances, device_grouped_gemm_xdl_splitk_2Bt_rcr_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp deleted file mode 100644 index c98328e52d..0000000000 --- 
a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,234 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using BF16 = ck::bhalf_t; -using I8 = int8_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Multiply = ck::tensor_operation::element_wise::Multiply; -using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; -using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; -using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; - -static constexpr auto GemmDefault = GemmSpecialization::Default; -static constexpr auto GemmKPadding = GemmSpecialization::KPadding; -static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; - -static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; -static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; - -template -using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off - //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, 
I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8,8,1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> - - // clang-format on - >; - -template -using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = - std::tuple< - // clang-format off - //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| 
Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - // Memory friendly - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 
1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> - // clang-format on - >; - -void add_device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - Row, - BF16, - I8, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - Multiply>>>& instances) -{ - // comp - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - Multiply, - GemmDefault>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - Multiply, - GemmMNKPadding>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - Multiply, - GemmMNPadding>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - Multiply, - GemmKPadding>{}); - // mem - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - Multiply, - GemmDefault, - Intrawave>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - Multiply, - GemmMNKPadding, - Intrawave>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - Multiply, - GemmMNPadding, - Intrawave>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - Multiply, - GemmKPadding, - Intrawave>{}); - - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - Multiply, - GemmDefault, - Interwave>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - Multiply, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - Multiply, - GemmMNPadding, - Interwave>{}); - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - Multiply, - GemmKPadding, - Interwave>{}); -} - -void add_device_grouped_gemm_xdl_tile_loop_multiply_bias_fastgelu_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - Row, - BF16, - I8, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyAddFastGelu>>>& instances) -{ - 
add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances< - ck::Tuple, - ck::Tuple, - MultiplyAddFastGelu>{}); -} - -void add_device_grouped_gemm_xdl_tile_loop_multiply_fastgelu_bf16_i8_bf16_mk_kn_mn_instances( - std::vector, - Row, - BF16, - I8, - ck::Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances< - ck::Tuple, - ck::Tuple, - MultiplyFastGelu>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 0b73e4fcd1..c10cd0ea9f 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,7 +17,6 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/utility/fill.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" @@ -42,11 +41,14 @@ bool profile_grouped_gemm_impl(int do_verification, const std::vector& StrideAs, const std::vector& StrideBs, const std::vector& StrideCs, - int kbatch = 1, - int n_warmup = 1, - int n_iter = 10) + const std::vector& kbatches = {}, + int n_warmup = 1, + int n_iter = 10) { bool pass = true; + // TODO: Fixme - we do not pass compute data type here but need it + // to compute error thresholds. + using ComputeDataType = ADataType; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -75,6 +77,7 @@ bool profile_grouped_gemm_impl(int do_verification, std::vector> c_m_n_host_results; std::vector> c_m_n_device_results; + ComputeDataType max_abs_in_val = 0.f; for(std::size_t i = 0; i < group_count; i++) { a_m_k.push_back( @@ -93,17 +96,18 @@ bool profile_grouped_gemm_impl(int do_verification, << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc << std::endl; } - std::size_t num_thread = 1; switch(init_method) { case 0: break; case 1: - a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k[i]); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n[i]); + max_abs_in_val = 2.f; break; default: - a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + ck::utils::FillUniformDistribution{-0.5f, 0.5f}(a_m_k[i]); + ck::utils::FillUniformDistribution{-0.5f, 0.5f}(b_k_n[i]); + max_abs_in_val = 0.5f; } } @@ -164,7 +168,20 @@ bool profile_grouped_gemm_impl(int do_verification, BElementOp, CElementOp>; - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + // If kbatch would be bigger than 1, then we will use SplitK version. 
+ using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmSplitK, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); if(op_ptrs.size() <= 0) @@ -205,7 +222,6 @@ bool profile_grouped_gemm_impl(int do_verification, ref_invoker.Run(ref_argument); } } - // profile device GEMM instances for(auto& gemm_ptr : op_ptrs) { @@ -221,43 +237,44 @@ bool profile_grouped_gemm_impl(int do_verification, auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + std::size_t workspace_size = gemm_ptr->GetWorkSpaceSize(argument_ptr.get()); + std::size_t kargs_size = gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()); - gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); - std::string gemm_name = gemm_ptr->GetTypeString(); + DeviceMem gemm_workspace, gemm_kargs; - using DeviceOpSplitK = ck::tensor_operation::device::DeviceGroupedGemmSplitK, - CLayout, - ADataType, - BDataType, - ck::Tuple<>, - CDataType, - AElementOp, - BElementOp, - CElementOp>; - - // skip non-splitk grouped_gemm - if(dynamic_cast(gemm_ptr.get()) == nullptr) + // The following is necessary since TwoStage kernel is using additional memory both + // for Workspace and kernel arguments. + if(kargs_size > 0) { - continue; + gemm_kargs.Realloc(kargs_size); + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), gemm_kargs.GetDeviceBuffer()); } + if(workspace_size > 0 && workspace_size != kargs_size) + { + gemm_workspace.Realloc(workspace_size); + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_workspace.GetDeviceBuffer()); + } + + std::string gemm_name = gemm_ptr->GetTypeString(); std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}; - if(kbatch > 0) + // If the user will provide not empty kbatches list, then we test predefined set of kbatch + // values. 
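Illustrative sketch (not part of the patch): the rewritten profiler loop above is the clearest picture of the unified device API. Kernel arguments and workspace are sized and bound through the generic interface (`GetDeviceKernelArgSize`/`SetDeviceKernelArgs` and `GetWorkSpaceSize`/`SetWorkSpacePointer`), and split-K becomes an optional capability reached through a `dynamic_cast` to `DeviceGroupedGemmSplitK` rather than a hard requirement that skipped every non-split-K instance. Condensed, the per-instance sequence looks roughly as follows; `gemm_ptr`, `argument_ptr`, `n_warmup`, `n_iter` and `kbatch` stand in for the variables used in the profiler code above.

    // Size and bind device-side kernel arguments and (optionally) a separate
    // workspace. TwoStage-style kernels report a workspace that differs from
    // the kernel-argument buffer; single-stage kernels typically do not.
    std::size_t kargs_size     = gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get());
    std::size_t workspace_size = gemm_ptr->GetWorkSpaceSize(argument_ptr.get());

    DeviceMem kargs_buf, workspace_buf;
    if(kargs_size > 0)
    {
        kargs_buf.Realloc(kargs_size);
        gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), kargs_buf.GetDeviceBuffer());
    }
    if(workspace_size > 0 && workspace_size != kargs_size)
    {
        workspace_buf.Realloc(workspace_size);
        gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_buf.GetDeviceBuffer());
    }

    // Split-K is opt-in: only instances implementing DeviceGroupedGemmSplitK
    // accept a K-batch larger than one; all other instances run with kbatch == 1.
    if(kbatch > 1)
    {
        if(auto* splitk_ptr = dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()))
        {
            splitk_ptr->SetKBatchSize(argument_ptr.get(), kbatch);
        }
    }

    if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
    {
        // Verification afterwards uses data-type-aware rtol/atol thresholds
        // (ck::utils::get_relative_threshold / get_absolute_threshold, as in the
        // diff above) instead of the old fixed 0.06 tolerance.
        gemm_ptr->MakeInvokerPointer()->Run(
            argument_ptr.get(), StreamConfig{nullptr, false, 0, n_warmup, n_iter});
    }

The kbatch handling that follows simply decides which values feed this sequence: when the caller passes a non-empty `kbatches` list it replaces the default sweep of {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}.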
+ if(!kbatches.empty()) { - kbatch_list = {kbatch}; + kbatch_list = kbatches; } for(std::size_t j = 0; j < kbatch_list.size(); j++) { - auto kbatch_curr = kbatch_list[j]; - dynamic_cast(gemm_ptr.get()) - ->SetKBatchSize(argument_ptr.get(), kbatch_curr); + if(kbatch_curr > 1 && dynamic_cast(gemm_ptr.get()) != nullptr) + { + dynamic_cast(gemm_ptr.get()) + ->SetKBatchSize(argument_ptr.get(), kbatch_curr); + } if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { @@ -272,23 +289,18 @@ bool profile_grouped_gemm_impl(int do_verification, bool instance_pass = true; for(std::size_t i = 0; i < gemm_descs.size(); i++) { - c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + auto atol = ck::utils::get_absolute_threshold( + max_abs_in_val, gemm_descs[i].K_); + auto rtol = ck::utils::get_relative_threshold( + gemm_descs[i].K_); - if(std::is_same_v && kbatch_curr > 1) - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_results[i], - "Error: Incorrect results!", - 0.06); - } - else - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_results[i]); - } + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i], + "Error: Incorrect results!", + rtol, + atol); if(do_log) { @@ -311,11 +323,12 @@ bool profile_grouped_gemm_impl(int do_verification, pass = pass && instance_pass; } - float ave_time = invoker_ptr->Run( - argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); - if(time_kernel) { + float ave_time = + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + std::size_t flop = 0, num_btype = 0; for(std::size_t i = 0; i < gemm_descs.size(); i++) { diff --git a/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp index f665644162..94ee2a37e4 100644 --- a/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp @@ -143,8 +143,7 @@ bool profile_grouped_gemm_multiply_tile_loop_impl(int do_verification, p_ds.reserve(group_count); p_e.reserve(group_count); - using KernelArguments = - ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; std::vector gemm_descs; std::vector gemm_kargs; diff --git a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp index 74faf15be3..3a4ca24dda 100644 --- a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp @@ -127,7 +127,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification, p_b.reserve(group_count); p_c.reserve(group_count); - using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<>; + using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument<>; std::vector gemm_descs; std::vector gemm_kargs; diff --git a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp deleted file mode 100644 index 14df96d505..0000000000 --- a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp +++ /dev/null @@ -1,367 +0,0 @@ -// SPDX-License-Identifier: MIT -// 
Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/convolution_parameter.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -namespace ck { -namespace profiler { - -template -bool profile_grouped_gemm_two_stage_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - int kbatch = 1, - int n_warmup = 1, - int n_iter = 10) -{ - bool pass = true; - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - std::size_t group_count = Ms.size(); - - if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && - group_count == StrideBs.size() && group_count == StrideCs.size())) - { - throw std::runtime_error("wrong! 
inconsistent M/N/Ks, StrideA/B/Cs size\n"); - } - - std::vector> a_m_k; - std::vector> b_k_n; - std::vector> c_m_n_host_results; - std::vector> c_m_n_device_results; - - for(std::size_t i = 0; i < group_count; i++) - { - a_m_k.push_back( - Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); - b_k_n.push_back( - Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); - - c_m_n_device_results.push_back( - Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - - c_m_n_host_results.push_back( - Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" - << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i - << "]:" << c_m_n_device_results[i].mDesc << std::endl; - } - std::size_t num_thread = 1; - switch(init_method) - { - case 0: break; - case 1: - a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - default: - a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - } - } - - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - - using DeviceMemPtr = std::unique_ptr; - std::vector a_device_buf, b_device_buf, c_device_buf; - - a_device_buf.reserve(group_count); - b_device_buf.reserve(group_count); - c_device_buf.reserve(group_count); - - std::vector p_a, p_b; - std::vector p_c; - - p_a.reserve(group_count); - p_b.reserve(group_count); - p_c.reserve(group_count); - - std::vector gemm_descs; - - gemm_descs.reserve(group_count); - - for(std::size_t i = 0; i < group_count; i++) - { - a_device_buf.emplace_back( - std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); - b_device_buf.emplace_back( - std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); - c_device_buf.emplace_back(std::make_unique( - sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); - - a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); - b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); - - gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); - - p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); - p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); - p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); - } - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm, - CLayout, - ADataType, - BDataType, - ck::Tuple<>, - CDataType, - AElementOp, - BElementOp, - CElementOp>; - - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - if(op_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! 
no device GEMM instance found"); - } - - std::string best_gemm_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - float best_kbatch = 0; - - auto p_ds = std::vector>{}; - - if(do_verification) - { - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], - b_k_n[i], - c_m_n_host_results[i], - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - } - } - - // profile device GEMM instances - for(auto& gemm_ptr : op_ptrs) - { - auto argument_ptr = - gemm_ptr->MakeArgumentPointer(p_a, - p_b, - p_ds, - p_c, - gemm_descs, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}); - - auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); - - DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); - gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); - - std::string gemm_name = gemm_ptr->GetTypeString(); - - using DeviceOpSplitK = - ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK, - CLayout, - ADataType, - BDataType, - ck::Tuple<>, - CDataType, - AElementOp, - BElementOp, - CElementOp>; - - // skip non-splitk grouped_gemm - if(dynamic_cast(gemm_ptr.get()) == nullptr) - { - continue; - } - - std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}; - - if(kbatch > 0) - { - kbatch_list = {kbatch}; - } - - for(std::size_t j = 0; j < kbatch_list.size(); j++) - { - - auto kbatch_curr = kbatch_list[j]; - dynamic_cast(gemm_ptr.get()) - ->SetKBatchSize(argument_ptr.get(), kbatch_curr); - - DeviceMem gemm_arg_dev_mem(dynamic_cast(gemm_ptr.get()) - ->GetDeviceKernelArgSize(argument_ptr.get())); - dynamic_cast(gemm_ptr.get()) - ->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer()); - - if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) - { - gemm_desc_workspace.SetZero(); - for(std::size_t i = 0; i < gemm_descs.size(); i++) - c_device_buf[i]->SetZero(); - - invoker_ptr->Run(argument_ptr.get(), - StreamConfig{nullptr, false, 0, n_warmup, n_iter}); - if(do_verification) - { - bool instance_pass = true; - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); - if(std::is_same_v && kbatch_curr > 1) - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_results[i], - "Error: Incorrect results!", - 0.06); - } - else - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_results[i]); - } - - if(do_log) - { - LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") - << std::endl; - LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "c_host : ", c_m_n_host_results[i].mData, ",") - << std::endl; - } - } - - std::cout << "Instance: " << gemm_name << " verification " - << (instance_pass ? 
"SUCCEED" : "FAILED") << std::endl; - - pass = pass && instance_pass; - } - float ave_time = invoker_ptr->Run( - argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); - if(time_kernel) - { - std::size_t flop = 0, num_btype = 0; - for(std::size_t i = 0; i < gemm_descs.size(); i++) - { - flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; - - num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + - sizeof(BDataType) * Ks[i] * Ns[i] + - sizeof(CDataType) * Ms[i] * Ns[i]; - } - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops - << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " - << kbatch_curr << std::endl; - - if(tflops > best_tflops) - { - best_gemm_name = gemm_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - best_kbatch = kbatch_curr; - } - } - } - else - { - std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" - << std::endl; - } - } - } - - if(time_kernel) - { - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch - << std::endl; - } - - return pass; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index f079d554bf..35e91f8172 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -43,7 +43,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_tile_loop.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp) diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index fbf44d720f..2adcd6483a 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -39,16 +39,13 @@ namespace { std::vector argToIntArray(char* input) { std::vector out; - std::istringstream in(input); - std::string item; while(std::getline(in, item, ',')) { out.push_back(std::stoi(item)); } - return out; } @@ -69,7 +66,7 @@ int profile_grouped_gemm(int argc, char* argv[]) << "arg7: time kernel (0=n0, 1=yes)\n" << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " "64,64 64,64 128,128)\n" - << "arg15: kbatch value (default 1)\n" + << "arg15: kbatch values (default 1)\n" << "optional:\n" << "arg16: number of warm-up cycles (default 1)\n" << "arg17: number of iterations (default 10)\n" @@ -92,7 +89,7 @@ int profile_grouped_gemm(int argc, char* argv[]) const auto StrideAs = argToIntArray(argv[11]); const auto StrideBs = argToIntArray(argv[12]); const auto StrideCs = argToIntArray(argv[13]); - const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; + const auto kbatches = argc >= 15 ? 
argToIntArray(argv[14]) : std::vector{}; int n_warmup = 1; int n_iter = 10; @@ -102,7 +99,6 @@ int profile_grouped_gemm(int argc, char* argv[]) n_iter = std::stoi(argv[16]); } -#ifdef CK_ENABLE_FP16 if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup, + n_iter); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup, + n_iter); + } + else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_grouped_gemm_impl(do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, n_warmup, n_iter); } @@ -239,7 +301,6 @@ int profile_grouped_gemm(int argc, char* argv[]) { throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); } -#endif return 0; } diff --git a/profiler/src/profile_grouped_gemm_fixed_nk.cpp b/profiler/src/profile_grouped_gemm_fixed_nk.cpp index de90a33ef4..e33d798504 100644 --- a/profiler/src/profile_grouped_gemm_fixed_nk.cpp +++ b/profiler/src/profile_grouped_gemm_fixed_nk.cpp @@ -32,9 +32,7 @@ namespace { std::vector argToIntArray(char* input) { std::vector out; - std::istringstream in(input); - std::string item; while(std::getline(in, item, ',')) @@ -83,7 +81,7 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) const auto StrideAs = argToIntArray(argv[11]); const auto StrideBs = argToIntArray(argv[12]); const auto StrideCs = argToIntArray(argv[13]); - const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; + const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1; using F32 = float; using F16 = ck::half_t; @@ -97,8 +95,8 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) int n_iter = 10; if(argc == 17) { - n_warmup = std::stoi(argv[16]); - n_iter = std::stoi(argv[17]); + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); } #if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) diff --git a/profiler/src/profile_grouped_gemm_two_stage.cpp b/profiler/src/profile_grouped_gemm_two_stage.cpp deleted file mode 100644 index db37a0b762..0000000000 --- a/profiler/src/profile_grouped_gemm_two_stage.cpp +++ /dev/null @@ -1,228 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
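Illustrative aside (not part of the patch): with the `profile_grouped_gemm.cpp` change above, the profiler's kbatch argument is parsed by `argToIntArray` into a list rather than a single value, so several split-K factors can be benchmarked in one run. A tiny hedged usage sketch:

    // argToIntArray (shown above) splits the comma-separated CLI argument.
    char arg[] = "1,4,8";
    std::vector<int> kbatches = argToIntArray(arg); // -> {1, 4, 8}
    // When the argument is omitted (argc < 15) the vector stays empty and
    // profile_grouped_gemm_impl falls back to its built-in kbatch sweep.
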
- -#include -#include -#include -#include - -#include "profiler/profile_grouped_gemm_two_stage_impl.hpp" -#include "profiler_operation_registry.hpp" - -enum struct GemmMatrixLayout -{ - MK_KN_MN, // 0 - MK_NK_MN, // 1 -}; - -enum struct GemmDataType -{ - F16_F16_F16, // 0 - BF16_INT8_BF16, // 1 - BF16_BF16_BF16 // 2 -}; - -#define OP_NAME "grouped_gemm_two_stage" -#define OP_DESC "Grouped GEMM TwoStage" - -namespace { - -std::vector argToIntArray(char* input) -{ - std::vector out; - - std::istringstream in(input); - - std::string item; - - while(std::getline(in, item, ',')) - { - out.push_back(std::stoi(item)); - } - - return out; -} - -int profile_grouped_gemm_two_stage(int argc, char* argv[]) -{ - if(argc < 14) - { - std::cout - << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" - << "arg2: data type (0: fp16; 1: bf16@int8; 2: bf16)\n" - << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n" - << "arg4: verification (0: no; 1: yes)\n" - << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" - << "arg6: print tensor value (0: no; 1: yes)\n" - << "arg7: time kernel (0=n0, 1=yes)\n" - << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " - "64,64 64,64 128,128)\n" - << "arg15: kbatch value (default 1)\n" - << "optional:\n" - << "arg16: number of warm-up cycles (default 1)\n" - << "arg17: number of iterations (default 10)\n" - << std::endl; - - exit(1); - } - - const auto data_type = static_cast(std::stoi(argv[2])); - const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[4]); - const int init_method = std::stoi(argv[5]); - const bool do_log = std::stoi(argv[6]); - const bool time_kernel = std::stoi(argv[7]); - - const auto Ms = argToIntArray(argv[8]); - const auto Ns = argToIntArray(argv[9]); - const auto Ks = argToIntArray(argv[10]); - - auto StrideAs = argToIntArray(argv[11]); - auto StrideBs = argToIntArray(argv[12]); - auto StrideCs = argToIntArray(argv[13]); - const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; - - const int DefaultStrideA = Ks[0]; - const int DefaultStrideB = Ns[0]; - const int DefaultStrideC = Ns[0]; - - for(size_t i = 0; i < Ms.size(); ++i) - { - StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i]; - StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i]; - StrideCs[i] = StrideCs[i] == -1 ? 
DefaultStrideC : StrideCs[i]; - } - - int n_warmup = 1; - int n_iter = 10; - if(argc == 17) - { - n_warmup = std::stoi(argv[16]); - n_iter = std::stoi(argv[17]); - } - - if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_grouped_gemm_two_stage_impl( - do_verification, - init_method, - do_log, - time_kernel, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatch, - n_warmup, - n_iter); - } - else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_grouped_gemm_two_stage_impl( - do_verification, - init_method, - do_log, - time_kernel, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatch, - n_warmup, - n_iter); - } - else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_grouped_gemm_two_stage_impl( - do_verification, - init_method, - do_log, - time_kernel, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatch, - n_warmup, - n_iter); - } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN) - { - ck::profiler::profile_grouped_gemm_two_stage_impl( - do_verification, - init_method, - do_log, - time_kernel, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatch, - n_warmup, - n_iter); - } - else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN) - { - ck::profiler::profile_grouped_gemm_two_stage_impl( - do_verification, - init_method, - do_log, - time_kernel, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatch, - n_warmup, - n_iter); - } - else - { - throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); - } - return 0; -} - -} // anonymous namespace - -REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_two_stage); diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 55cb209772..f47685cf91 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -6,12 +6,6 @@ if(result EQUAL 0) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) endif() -add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance) - add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk) -endif() - add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) diff --git a/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp index d9282fa924..74d49eb576 100644 --- a/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp +++ b/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
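
The split-K test rework that follows replaces the value-parameterized TEST_P fixtures with one typed suite, so each case is written once and instantiated for every (ALayout, BLayout, ELayout, ADataType, BDataType, EDataType) tuple in KernelTypes; adding a data-type combination then only means appending a tuple. A minimal, self-contained sketch of that GoogleTest pattern, with placeholder layout and element types standing in for the CK ones:

#include <tuple>
#include <gtest/gtest.h>

// Placeholder layout/element types; the real suite uses the CK layout tags
// and ck::half_t / ck::bhalf_t / ck::f8_t / int8_t.
struct Row {};
struct Col {};
using F16 = float;

template <typename Tuple>
class TestGroupedGemmSketch : public ::testing::Test
{
    protected:
    using ALayout = std::tuple_element_t<0, Tuple>;
    using BLayout = std::tuple_element_t<1, Tuple>;
    // Elements 2..5 select ELayout and the A/B/E data types the same way.
};

using KernelTypesSketch = ::testing::Types<std::tuple<Row, Row, Row, F16, F16, F16>,
                                           std::tuple<Row, Col, Row, F16, F16, F16>>;

TYPED_TEST_SUITE(TestGroupedGemmSketch, KernelTypesSketch);

// A TYPED_TEST written once runs for every tuple listed above.
TYPED_TEST(TestGroupedGemmSketch, CompilesForAllConfigs) { SUCCEED(); }
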
 #include
 #include
@@ -10,25 +10,35 @@
 #include "gtest/gtest.h"
 #include "test_grouped_gemm_util.hpp"
 
-using F16 = ck::half_t;
+using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F8   = ck::f8_t;
+using I8   = int8_t;
+
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;
 
-using RRR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
-using RCR_F16_F16_F16 = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
+template <typename Tuple>
+class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
+{
+};
 
-using RRR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Row, Row, F16, F16, F16>>;
-using RCR_F16_F16_F16_LargeK = ck::test::TestGroupedGemm<std::tuple<Row, Col, Row, F16, F16, F16>>;
+// clang-format off
+using KernelTypes = ::testing::Types<
+    std::tuple< Row, Row, Row, F16, F16, F16>,
+    std::tuple< Row, Col, Row, F16, F16, F16>,
+    std::tuple< Col, Row, Row, F16, F16, F16>,
+    std::tuple< Col, Col, Row, F16, F16, F16>,
+    std::tuple< Row, Row, Row, BF16, BF16, BF16>,
+    std::tuple< Row, Col, Row, BF16, BF16, BF16>,
+    std::tuple< Col, Row, Row, BF16, BF16, BF16>,
+    std::tuple< Row, Row, Row, BF16, I8, BF16>,
+    std::tuple< Row, Col, Row, BF16, I8, BF16>,
+    std::tuple< Row, Row, Row, F16, F8, F16>,
+    std::tuple< Row, Row, Row, F8, F16, F16>
+    >;
+// clang-format on
 
-const std::vector<int> KBATCH{1, 2, 3, 5, 8};
-
-INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_KN, RRR_F16_F16_F16, testing::ValuesIn(KBATCH));
-INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_MK_NK, RCR_F16_F16_F16, testing::ValuesIn(KBATCH));
-INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_KN,
-                         RRR_F16_F16_F16_LargeK,
-                         testing::Values(32, 64));
-INSTANTIATE_TEST_SUITE_P(TestGroupedGemm_splitk_LargeK_MK_NK,
-                         RCR_F16_F16_F16_LargeK,
-                         testing::Values(32, 64));
+TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
 
 #include "test_grouped_gemm_ut_cases.inc"
diff --git a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc
index d94d140d97..f4011cf998 100644
--- a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc
+++ b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc
@@ -1,6 +1,6 @@
 #pragma once
 
-TEST_P(RRR_F16_F16_F16, TinyCases)
+TYPED_TEST(TestGroupedGemm, TinyCases)
 {
     const std::vector<int> Ms{0, 1};
     constexpr int N = 768;
@@ -8,14 +8,11 @@ TEST_P(RRR_F16_F16_F16, TinyCases)
     const std::vector<int> Ns(Ms.size(), N);
     const std::vector<int> Ks(Ms.size(), K);
 
-    const std::vector<int> StrideAs(Ms.size(), K);
-    const std::vector<int> StrideBs(Ms.size(), N);
-    const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks);
 }
 
-TEST_P(RRR_F16_F16_F16, SmallCases)
+TYPED_TEST(TestGroupedGemm, SmallCases)
 {
     const std::vector<int> Ms{2, 1, 3, 4, 5, 0};
     constexpr int N = 768;
@@ -23,14 +20,11 @@ TEST_P(RRR_F16_F16_F16, SmallCases)
     const std::vector<int> Ns(Ms.size(), N);
     const std::vector<int> Ks(Ms.size(), K);
 
-    const std::vector<int> StrideAs(Ms.size(), K);
-    const std::vector<int> StrideBs(Ms.size(), N);
-    const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms, Ns, Ks);
 }
 
-TEST_P(RRR_F16_F16_F16, MidCases)
+TYPED_TEST(TestGroupedGemm, MidCases)
 {
     const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
     constexpr int N = 768;
@@ -38,14 +32,11 @@ TEST_P(RRR_F16_F16_F16, MidCases)
     const std::vector<int> Ns(Ms.size(), N);
     const std::vector<int> Ks(Ms.size(), K);
 
-    const std::vector<int> StrideAs(Ms.size(), K);
-    const std::vector<int> StrideBs(Ms.size(), N);
-    const std::vector<int> StrideCs(Ms.size(), N);
 
-    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
+    this->Run(Ms,
Ns, Ks); } -TEST_P(RRR_F16_F16_F16, Regular) +TYPED_TEST(TestGroupedGemm, Regular) { const std::vector Ms{64, 128, 256}; constexpr int N = 768; @@ -53,14 +44,11 @@ TEST_P(RRR_F16_F16_F16, Regular) const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), N); - const std::vector StrideCs(Ms.size(), N); - this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); + this->Run(Ms, Ns, Ks); } -TEST_P(RRR_F16_F16_F16, MNKPadded) +TYPED_TEST(TestGroupedGemm, MNKPadded) { const std::vector Ms{127, 150, 188, 210}; constexpr int N = 136; @@ -68,88 +56,11 @@ TEST_P(RRR_F16_F16_F16, MNKPadded) const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), N); - const std::vector StrideCs(Ms.size(), N); - this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); + this->Run(Ms, Ns, Ks); } -TEST_P(RCR_F16_F16_F16, TinyCases) -{ - const std::vector Ms{0, 1}; - constexpr int N = 768; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), K); - const std::vector StrideCs(Ms.size(), N); - this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); -} - -TEST_P(RCR_F16_F16_F16, SmallCases) -{ - const std::vector Ms{2, 1, 3, 4, 5, 0}; - constexpr int N = 768; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), K); - const std::vector StrideCs(Ms.size(), N); - - this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); -} - -TEST_P(RCR_F16_F16_F16, MidCases) -{ - const std::vector Ms{167, 183, 177, 153, 139, 204}; - constexpr int N = 768; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), K); - const std::vector StrideCs(Ms.size(), N); - - this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); -} - -TEST_P(RCR_F16_F16_F16, Regular) -{ - const std::vector Ms{32, 64, 128, 256}; - constexpr int N = 768; - constexpr int K = 320; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), K); - const std::vector StrideCs(Ms.size(), N); - - this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); -} - -TEST_P(RCR_F16_F16_F16, MNKPadded) -{ - const std::vector Ms{127, 150, 188, 210}; - constexpr int N = 136; - constexpr int K = 280; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), K); - const std::vector StrideCs(Ms.size(), N); - - this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); -} - -TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch) +TYPED_TEST(TestGroupedGemm, TestLargeKBatch) { const std::vector Ms{188, 210}; constexpr int N = 768; @@ -157,24 +68,8 @@ TEST_P(RRR_F16_F16_F16_LargeK, TestLargeKBatch) const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), N); - const std::vector StrideCs(Ms.size(), N); - this->Run(Ms, Ns, Ks, 
StrideAs, StrideBs, StrideCs, this->GetParam()); -} - -TEST_P(RCR_F16_F16_F16_LargeK, TestLargeKBatch) -{ - const std::vector Ms{188, 210}; - constexpr int N = 768; - constexpr int K = 4096; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - const std::vector StrideAs(Ms.size(), K); - const std::vector StrideBs(Ms.size(), K); - const std::vector StrideCs(Ms.size(), N); - - this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); + this->k_batches_ = {32, 64}; + + this->Run(Ms, Ns, Ks); } diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 9e1395b9f8..a3ab0e087c 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -22,7 +22,6 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/number.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" -#include "profiler/profile_grouped_gemm_two_stage_impl.hpp" namespace ck { namespace test { @@ -40,7 +39,7 @@ std::string serialize_range(const Range& range) } template -class TestGroupedGemm : public testing::TestWithParam +class TestGroupedGemm : public testing::Test { protected: using ALayout = std::tuple_element_t<0, Tuple>; @@ -50,23 +49,77 @@ class TestGroupedGemm : public testing::TestWithParam using BDataType = std::tuple_element_t<4, Tuple>; using EDataType = std::tuple_element_t<5, Tuple>; + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + public: static constexpr bool verify_ = true; - static constexpr int init_method_ = 1; // decimal value initialization + static constexpr int init_method_ = 1; // integer value initialization static constexpr bool log_ = false; static constexpr bool bench_ = false; // measure kernel performance + static constexpr int n_warmup_ = 0; + static constexpr int n_iter_ = 1; + std::vector k_batches_; - void SetUp() override {} + void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; } + private: + template + void SetStrides(std::vector& strides, + const std::vector& rows, + const std::vector& cols) const + { + if(std::is_same_v) + { + for(const auto c : cols) + { + strides.emplace_back(c); + } + } + else if(std::is_same_v) + { + for(const auto r : rows) + { + strides.emplace_back(r); + } + } + } + + public: void Run(const std::vector& Ms, const std::vector& Ns, const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - int kbatch = 1, - int n_warmup = 1, - int n_iter = 10) + const std::vector& StrideAs = {}, + const std::vector& StrideBs = {}, + const std::vector& StrideCs = {}) + { + std::vector stride_as = StrideAs; + std::vector stride_bs = StrideBs; + std::vector stride_cs = StrideCs; + + if(stride_as.empty()) + { + SetStrides(stride_as, Ms, Ks); + } + if(stride_bs.empty()) + { + SetStrides(stride_bs, Ks, Ns); + } + if(stride_cs.empty()) + { + SetStrides(stride_cs, Ms, Ns); + } + + RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches_); + } + + void RunSingle(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + const std::vector& kbatches) { bool pass = ck::profiler::profile_grouped_gemm_impl StrideAs, StrideBs, StrideCs, - kbatch, - n_warmup, - n_iter); - EXPECT_TRUE(pass); - } -}; - -template -class TestGroupedGemmTwoStage : public testing::TestWithParam -{ - protected: - using ALayout = std::tuple_element_t<0, Tuple>; 
-        using BLayout   = std::tuple_element_t<1, Tuple>;
-        using ELayout   = std::tuple_element_t<2, Tuple>;
-        using ADataType = std::tuple_element_t<3, Tuple>;
-        using BDataType = std::tuple_element_t<4, Tuple>;
-        using EDataType = std::tuple_element_t<5, Tuple>;
-
-    public:
-    static constexpr bool verify_     = true;
-    static constexpr int init_method_ = 1; // decimal value initialization
-    static constexpr bool log_        = false;
-    static constexpr bool bench_      = false; // measure kernel performance
-
-    void SetUp() override {}
-
-    void Run(const std::vector<int>& Ms,
-             const std::vector<int>& Ns,
-             const std::vector<int>& Ks,
-             const std::vector<int>& StrideAs,
-             const std::vector<int>& StrideBs,
-             const std::vector<int>& StrideCs,
-             int kbatch   = 1,
-             int n_warmup = 1,
-             int n_iter   = 10)
-    {
-        bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl(verify_,
-                                                                      init_method_,
-                                                                      log_,
-                                                                      bench_,
-                                                                      Ms,
-                                                                      Ns,
-                                                                      Ks,
-                                                                      StrideAs,
-                                                                      StrideBs,
-                                                                      StrideCs,
-                                                                      kbatch,
-                                                                      n_warmup,
-                                                                      n_iter);
+                                                         kbatches,
+                                                         n_warmup_,
+                                                         n_iter_);
         EXPECT_TRUE(pass);
     }
 };
@@ -263,7 +264,7 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
             p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
         if(kbatch > 1)
         {
-            ggemm_instance.SetKBatchSize(argument, kbatch);
+            ggemm_instance.SetKBatchSize(&argument, kbatch);
         }
 
         return ggemm_instance.IsSupportedArgument(argument);
@@ -300,13 +301,13 @@ struct DeviceGroupedGemmSplitkInstanceWrapper
             p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
         if(kbatch > 1)
        {
-            ggemm_instance.SetKBatchSize(argument, kbatch);
+            ggemm_instance.SetKBatchSize(&argument, kbatch);
         }
 
         EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
         auto invoker = ggemm_instance.MakeInvoker();
 
-        DeviceMem gemm_desc_workspace(ggemm_instance.GetWorkSpaceSize(&argument));
-        ggemm_instance.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
+        DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
+        ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
 
         return invoker.Run(argument, StreamConfig{nullptr, false});
     }
};
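
The wrapper above now sizes and binds the per-group kernel arguments through GetDeviceKernelArgSize and SetDeviceKernelArgs (replacing the former GetWorkSpaceSize/SetWorkSpacePointer pair), and SetKBatchSize receives the argument by pointer. A minimal sketch of the resulting launch sequence; the instance, argument, device-memory and stream-config types are template placeholders, and only the member functions visible in the diff are assumed:

// Generic outline of the updated grouped-GEMM launch sequence.
template <typename GemmInstance,
          typename Argument,
          typename DeviceMemory,
          typename StreamCfg>
float launch_grouped_gemm(GemmInstance& ggemm_instance,
                          Argument& argument,
                          int kbatch,
                          const StreamCfg& stream_config)
{
    if(kbatch > 1)
    {
        // SetKBatchSize now takes a pointer to the argument struct.
        ggemm_instance.SetKBatchSize(&argument, kbatch);
    }

    // Allocate device memory for the per-group kernel arguments and hand
    // it to the instance before launching.
    DeviceMemory dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
    ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());

    // Run returns the measured (or zero) kernel time, as in the wrapper above.
    auto invoker = ggemm_instance.MakeInvoker();
    return invoker.Run(argument, stream_config);
}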