From f4e3fb591579734cba77bf295b5854fd7b48e6d1 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Tue, 3 Jun 2025 13:23:34 +0000 Subject: [PATCH] Multiple ABD --- .../ck_tile/20_gemm_multi_abd/CMakeLists.txt | 1 + example/ck_tile/20_gemm_multi_abd/README.md | 35 ++++++ .../gemm_multi_abd_fp16.cpp} | 28 ++--- .../gemm_multi_abd_fp16.hpp} | 37 +++++-- .../run_gemm_multi_abd_fp16_example.inc} | 101 +++++++++-------- .../utils.hpp | 8 +- .../ck_tile/20_multi_abd_gemm/CMakeLists.txt | 1 - example/ck_tile/20_multi_abd_gemm/README.md | 33 ------ example/ck_tile/CMakeLists.txt | 2 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 103 ++++++++++++++++-- 10 files changed, 225 insertions(+), 124 deletions(-) create mode 100644 example/ck_tile/20_gemm_multi_abd/CMakeLists.txt create mode 100644 example/ck_tile/20_gemm_multi_abd/README.md rename example/ck_tile/{20_multi_abd_gemm/multi_abd_gemm.cpp => 20_gemm_multi_abd/gemm_multi_abd_fp16.cpp} (95%) rename example/ck_tile/{20_multi_abd_gemm/multi_abd_gemm.hpp => 20_gemm_multi_abd/gemm_multi_abd_fp16.hpp} (68%) rename example/ck_tile/{20_multi_abd_gemm/run_multi_abd_gemm_example.inc => 20_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc} (79%) rename example/ck_tile/{20_multi_abd_gemm => 20_gemm_multi_abd}/utils.hpp (89%) delete mode 100644 example/ck_tile/20_multi_abd_gemm/CMakeLists.txt delete mode 100644 example/ck_tile/20_multi_abd_gemm/README.md diff --git a/example/ck_tile/20_gemm_multi_abd/CMakeLists.txt b/example/ck_tile/20_gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..f382e0cf45 --- /dev/null +++ b/example/ck_tile/20_gemm_multi_abd/CMakeLists.txt @@ -0,0 +1 @@ +add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp) diff --git a/example/ck_tile/20_gemm_multi_abd/README.md b/example/ck_tile/20_gemm_multi_abd/README.md new file mode 100644 index 0000000000..c272df3fb5 --- /dev/null +++ b/example/ck_tile/20_gemm_multi_abd/README.md @@ -0,0 +1,35 @@ +#Multiple ABD GEMM 
+ +This folder contains an example of Multiple ABD GEMM using the ck_tile tile-programming implementation. + +## build +``` +#in the root of ck_tile +mkdir build && cd build +#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \ + leave it blank +sh ../script/cmake-ck-dev.sh ../ +#The basic pipeline method on the gemm calculation +make tile_example_gemm_multi_abd_fp16 -j +``` +This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16` + +## example +``` +args: + -m M dimensions - (Default: 3840) + -n N dimensions - (Default: 4096) + -k K dimensions - (Default: 4096) +-as_layout Tensor A layout (default:R) +-bs_layout Tensor B layout (default:C) +-ds_layout Tensor D layout (default:R) +-e_layout Tensor E layout (default:R) +-stride_as Tensor A strides - (Default: 0) +-stride_bs Tensor B strides - (Default: 0) +-stride_e Tensor E strides - (Default: 0) +-stride_ds Tensor D strides - (Default: 0) +-v 0. No validation, 1. Validation on GPU. (Default: 1) + -warmup Number of iterations before benchmarking the kernel. (Default: 50) + -repeat Number of iterations to benchmark the kernel. (Default: 100) + -kbatch kbatch for SplitK. 
(Default: 1) +``` \ No newline at end of file diff --git a/example/ck_tile/20_multi_abd_gemm/multi_abd_gemm.cpp b/example/ck_tile/20_gemm_multi_abd/gemm_multi_abd_fp16.cpp similarity index 95% rename from example/ck_tile/20_multi_abd_gemm/multi_abd_gemm.cpp rename to example/ck_tile/20_gemm_multi_abd/gemm_multi_abd_fp16.cpp index 0162c02b66..0350f22233 100644 --- a/example/ck_tile/20_multi_abd_gemm/multi_abd_gemm.cpp +++ b/example/ck_tile/20_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -14,22 +14,22 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/host.hpp" -#include "multi_abd_gemm.hpp" +#include "gemm_multi_abd_fp16.hpp" #include "utils.hpp" template -auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::stream_config& s) -> float +auto gemm_multiple_abd(const gemm_multiple_abd_kargs& args, const ck_tile::stream_config& s) -> float { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Memory friendly for Interwave scheduler @@ -80,7 +80,7 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea constexpr bool DoubleSmemBuffer = true; #endif - constexpr bool kPadM = false; + constexpr bool kPadM = false; constexpr bool kPadN = false; constexpr bool kPadK = false; @@ -98,7 +98,7 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea using TilePartitioner = ck_tile:: GemmSpatiallyLocalTilePartitioner; - using Traits = ck_tile::TileGemmTraits; + using Traits = ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -132,8 +132,8 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; using BsDataType = ck_tile::tuple; @@ -56,21 +56,34 @@ auto create_args(int argc, char* argv[]) arg_parser.insert("m", "3840", "m dimension") .insert("n", "4096", "n dimension") 
.insert("k", "4096", "k dimension") - .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "C", "B tensor data layout - Col by default") - .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_d", "0", "Tensor Ds stride") - .insert("stride_c", "0", "Tensor C stride") + .insert("as_layout", "R", "As tensor data layout - Row by default") + .insert("bs_layout", "C", "Bs tensor data layout - Col by default") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default") + .insert("e_layout", "R", "E tensor data layout - Row by default") + .insert("stride_as", "0", "Tensor A stride") + .insert("stride_bs", "0", "Tensor B stride") + .insert("stride_ds", "0", "Tensor Ds stride") + .insert("stride_e", "0", "Tensor E stride") .insert("v", "1", "0. No validation, 1. Validation on GPU") .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel"); + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("kbatch", "1", "kbatch for SplitK"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } +using gemm_multiple_abd_kargs = ck_tile::GemmHostArgs; -using multiple_abd_gemm_kargs = ck_tile::GemmHostArgs; - -float multiple_abd_gemm(const multiple_abd_gemm_kargs& kargs, const ck_tile::stream_config& s); +template +float gemm_multiple_abd(const gemm_multiple_abd_kargs& kargs, const ck_tile::stream_config& s); diff --git a/example/ck_tile/20_multi_abd_gemm/run_multi_abd_gemm_example.inc b/example/ck_tile/20_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc similarity index 79% rename from example/ck_tile/20_multi_abd_gemm/run_multi_abd_gemm_example.inc rename to example/ck_tile/20_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc index 7ef6487299..e54920d657 100644 
--- a/example/ck_tile/20_multi_abd_gemm/run_multi_abd_gemm_example.inc +++ b/example/ck_tile/20_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc @@ -8,70 +8,70 @@ template -float invoke_multi_abd_gemm(const std::array& as_m_k_dev_buf, +float invoke_gemm_multi_abd(const std::array& as_m_k_dev_buf, const std::array& bs_k_n_dev_buf, [[maybe_unused]] const std::array& ds_m_n_dev_buf, - void* c_m_n_dev_buf, + void* e_m_n_dev_buf, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, const std::array& StrideAs, const std::array& StrideBs, [[maybe_unused]] const std::array& StrideDs, - ck_tile::index_t StrideC, + ck_tile::index_t StrideE, int n_warmup, - int n_repeat) + int n_repeat, + int k_batch) { - multiple_abd_gemm_kargs gemm_descs({as_m_k_dev_buf, + gemm_multiple_abd_kargs gemm_descs({as_m_k_dev_buf, bs_k_n_dev_buf, //ds_m_n_dev_buf, - c_m_n_dev_buf, - /*kbatch */ 1, + e_m_n_dev_buf, + k_batch, M, N, K, StrideAs, StrideBs, //StrideDs, - StrideC}); + StrideE}); - float ave_time = multiple_abd_gemm( - gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - std::string op_name{"Multiple-D Gemm"}; - //static constexpr ck_tile::index_t NumDTensor = DsDataType::size(); + std::string op_name{"Gemm Multiple-ABD"}; std::size_t flop = 0, num_btype = 0; flop += std::size_t(2) * M * N * K; - num_btype += sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(CDataType) * M * N; + num_btype += sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Run Multiple-D Gemm kernel with:\n"; + std::cout << "Run Gemm Multiple-ABD kernel with:\n"; std::cout << "M =" << M << " N =" << N << " K =" << K << "\n"; - //std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC + //std::cout << "StrideA = 
" << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE // << "\n"; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << "\n"; @@ -85,8 +85,8 @@ template -int run_multiple_abd_gemm_example_with_layouts(int argc, + typename ELayout> +int run_gemm_multi_abd_example_with_layouts(int argc, char* argv[], const A0Layout a0_layout = A0Layout{}, const A1Layout a1_layout = A1Layout{}, @@ -94,7 +94,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, const B1Layout b1_layout = B1Layout{}, const D0Layout d0_layout = D0Layout{}, const D1Layout d1_layout = D1Layout{}, - const CLayout c_layout = CLayout{}) + const ELayout e_layout = ELayout{}) { auto [result, arg_parser] = create_args(argc, argv); if(!result) @@ -112,10 +112,10 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); - ck_tile::index_t StrideA = arg_parser.get_int("stride_a"); - ck_tile::index_t StrideB = arg_parser.get_int("stride_b"); - ck_tile::index_t StrideD = arg_parser.get_int("stride_d"); - ck_tile::index_t StrideC = arg_parser.get_int("stride_c"); + ck_tile::index_t StrideA = arg_parser.get_int("stride_as"); + ck_tile::index_t StrideB = arg_parser.get_int("stride_bs"); + ck_tile::index_t StrideD = arg_parser.get_int("stride_ds"); + ck_tile::index_t StrideE = arg_parser.get_int("stride_e"); ck_tile::index_t StrideA0 = StrideA; ck_tile::index_t StrideA1 = StrideA; @@ -128,6 +128,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, const int n_warmup = arg_parser.get_int("warmup"); const int n_repeat = arg_parser.get_int("repeat"); + const int k_batch = arg_parser.get_int("kbatch"); StrideA0 = f_get_default_stride(M, N, StrideA0, a0_layout); StrideA1 = f_get_default_stride(M, N, StrideA1, a1_layout); @@ -138,7 +139,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, StrideD0 = f_get_default_stride(M, N, StrideD0, d0_layout); 
StrideD1 = f_get_default_stride(M, N, StrideD1, d1_layout); - StrideC = f_get_default_stride(M, N, StrideC, c_layout); + StrideE = f_get_default_stride(M, N, StrideE, e_layout); ck_tile::HostTensor a0_m_k_tesnor( f_host_tensor_descriptor(M, K, StrideA0, a0_layout)); @@ -155,8 +156,8 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, ck_tile::HostTensor d1_m_n_tensors( f_host_tensor_descriptor(M, N, StrideD1, d1_layout)); - ck_tile::HostTensor e_m_n_device_result( - f_host_tensor_descriptor(M, N, StrideC, c_layout)); + ck_tile::HostTensor e_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideE, e_layout)); ck_tile::FillUniformDistribution{-1.f, 1.f}(a0_m_k_tesnor); ck_tile::FillUniformDistribution{-1.f, 1.f}(a1_m_k_tesnor); @@ -176,7 +177,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data()); a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data()); @@ -187,7 +188,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); - c_m_n_dev_buf.SetZero(); + e_m_n_dev_buf.SetZero(); e_m_n_device_result.SetZero(); std::array as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(), @@ -203,36 +204,37 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, std::array strideBs = {StrideB0, StrideB1}; std::array strideDs = {StrideD0, StrideD1}; - invoke_multi_abd_gemm(as_ptr_buf, bs_ptr_buf, ds_ptr_buf, - c_m_n_dev_buf.GetDeviceBuffer(), + e_m_n_dev_buf.GetDeviceBuffer(), M, N, K, strideAs, strideBs, strideDs, - StrideC, + StrideE, n_warmup, - n_repeat); + 
n_repeat, + k_batch); - c_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); ck_tile::HostTensor a_m_k_host_ref(f_host_tensor_descriptor(M, K, StrideA0, a0_layout)); ck_tile::HostTensor b_k_n_host_ref(f_host_tensor_descriptor(K, N, StrideB0, b0_layout)); - ck_tile::HostTensor e_m_n_host_ref(f_host_tensor_descriptor(M, N, StrideC, c_layout)); + ck_tile::HostTensor e_m_n_host_ref(f_host_tensor_descriptor(M, N, StrideE, e_layout)); a_m_k_host_ref.SetZero(); b_k_n_host_ref.SetZero(); e_m_n_host_ref.SetZero(); @@ -242,7 +244,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc, B0DataType, D0DataType, AccDataType, - CDataType, + EDataType, AElementWiseFn, BElementWiseFn, CDEElementWiseFn>( @@ -279,18 +281,15 @@ int run_multiple_abd_gemm_example(int argc, char* argv[]) return -1; } - const std::string a_layout = arg_parser.get_str("a_layout"); - const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string as_layout = arg_parser.get_str("as_layout"); + const std::string bs_layout = arg_parser.get_str("bs_layout"); using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if(a_layout == "R" && b_layout == "C") + if(as_layout == "R" && bs_layout == "C") { - return run_multiple_abd_gemm_example_with_layouts( + return run_gemm_multi_abd_example_with_layouts( argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{}); } else diff --git a/example/ck_tile/20_multi_abd_gemm/utils.hpp b/example/ck_tile/20_gemm_multi_abd/utils.hpp similarity index 89% rename from example/ck_tile/20_multi_abd_gemm/utils.hpp rename to example/ck_tile/20_gemm_multi_abd/utils.hpp index f8f4bb4812..5b333d09d1 100644 --- a/example/ck_tile/20_multi_abd_gemm/utils.hpp +++ b/example/ck_tile/20_gemm_multi_abd/utils.hpp @@ -45,17 +45,17 @@ auto 
calculate_rtol_atol(const ck_tile::index_t K, std::conditional_t; // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( + const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( + const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); // Calculate error due to split_k accumulation const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); + ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = ck_tile::get_absolute_threshold( + const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); // Use higher threshold diff --git a/example/ck_tile/20_multi_abd_gemm/CMakeLists.txt b/example/ck_tile/20_multi_abd_gemm/CMakeLists.txt deleted file mode 100644 index 751d80b96e..0000000000 --- a/example/ck_tile/20_multi_abd_gemm/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_executable(tile_example_multi_abd_gemm EXCLUDE_FROM_ALL multi_abd_gemm.cpp) diff --git a/example/ck_tile/20_multi_abd_gemm/README.md b/example/ck_tile/20_multi_abd_gemm/README.md deleted file mode 100644 index b9d6a4c8b4..0000000000 --- a/example/ck_tile/20_multi_abd_gemm/README.md +++ /dev/null @@ -1,33 +0,0 @@ -#Multiple D GEMM - -This folder contains example for Multiple D GEMM using ck_tile tile-programming implementation. 
- -## build -``` -#in the root of ck_tile -mkdir build && cd build -#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \ - leave it blank -sh ../script/cmake-ck-dev.sh ../ -#The basic pipeline method on the gemm calculation -make tile_example_multi_d_gemm -j -``` -This will result in an executable `build/bin/tile_example_multi_d_gemm` - -## example -``` -args: - -m M dimensions - (Default: 3840) - -n N dimensions - (Default: 4096) - -k K dimensions - (Default: 4096) --a_layout Tensor A layout (default:R) --b_layout Tensor B layout (default:C) --c_layout Tensor C layout (default:R) --stride_a Tensor A strides - (Default: 0) --stride_b Tensor B strides - (Default: 0) --stride_c Tensor C strides - (Default: 0) --stride_d Tensor C strides - (Default: 0) --validate 0. No validation, 1. Validation on GPU. (Default: 1) - -warmup Number of iterations before benchmark the kernel. (Default: 10) - -repeat Number of iterations to benchmark the kernel. (Default: 100) -``` diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 3c72bbc3b0..cd969a3938 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -18,6 +18,6 @@ add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) -add_subdirectory(20_multi_abd_gemm) +add_subdirectory(20_gemm_multi_abd) add_subdirectory(35_batched_transpose) add_subdirectory(36_copy) diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 7e0876ee55..ef29fd2eca 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -13,6 +13,93 @@ namespace ck_tile { +/// @brief The GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments +/// object. 
It contain all necessary information required to build proper kernel argument +/// and launch kernel on GPU. +/// This structure defines the GEMM problem configuration by stating all required information +/// like M,N,K sizes and respective strides. +/// NumDTensor describes the number of D tensors. +template +struct GemmHostArgs +{ + CK_TILE_HOST GemmHostArgs() = default; + CK_TILE_HOST GemmHostArgs(const std::array& as_ptr_, + const std::array& bs_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + const std::array& stride_As_, + const std::array& stride_Bs_, + const std::array& stride_Ds_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const std::array as_ptr; + const std::array bs_ptr; + const std::array ds_ptr; + void* e_ptr; + index_t M; + index_t N; + index_t K; + const std::array stride_As; + const std::array stride_Bs; + const std::array stride_Ds; + index_t stride_E; + index_t k_batch; +}; + +/// @brief The GEMM kernel device arguments. +template , typename BType = ck_tile::tuple<>, typename DType = ck_tile::tuple<>> +struct GemmKernelArgs +{ + /// @brief The A input tensor's pointer to device memory. + const AType* a_ptr; + /// @brief The B input tensor's pointer to device memory. + const BType* b_ptr; + /// @brief The B input tensor's pointer to device memory. + const DType ds_ptr; + /// @brief The E output tensor's pointer to device memory. + void* e_ptr; + /// @brief GEMM's M dimension size. + index_t M; + /// @brief GEMM's N dimension size. + index_t N; + /// @brief GEMM's K dimension size. + index_t K; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of As tensor. 
+ std::array stride_As; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Bs tensor. + std::array stride_Bs; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Ds tensor. + std::array stride_Ds; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of E tensor. + index_t stride_E; + index_t k_batch; +}; + template struct GemmHostArgs { @@ -130,7 +217,7 @@ struct GemmKernel using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. - using CDataType = remove_cvref_t; + using EDataType = remove_cvref_t; static constexpr auto I0 = number<0>(); static constexpr auto I1 = number<1>(); @@ -267,7 +354,7 @@ struct GemmKernel CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) { if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value) + is_any_of::value) { if(kargs.k_batch != 1) { @@ -417,7 +504,7 @@ struct GemmKernel template CK_TILE_DEVICE static auto MakeGemmTensorViews(const AsGridPointer as_ptr, const BsGridPointer bs_ptr, - CDataType* c_ptr, + EDataType* c_ptr, const GemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset) { @@ -668,7 +755,7 @@ struct GemmKernel */ CK_TILE_DEVICE static void RunGemm(const AsGridPointer as_ptr, const BsGridPointer bs_ptr, - CDataType* c_ptr, + EDataType* c_ptr, void* smem_ptr_0, const GemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, @@ -720,7 +807,7 @@ struct GemmKernel */ CK_TILE_DEVICE static void RunGemm2LDS(const AsGridPointer as_ptr, const BsGridPointer bs_ptr, - CDataType* c_ptr, + EDataType* c_ptr, void* __restrict__ smem_ptr_0, void* __restrict__ smem_ptr_1, const GemmKernelArgs& kargs, @@ -774,7 +861,7 @@ struct GemmKernel bs_ptr(i) = static_cast(kargs.bs_ptr[i]) + splitk_batch_offset.b_k_split_offset[i]; }); - 
CDataType* c_ptr = static_cast(kargs.e_ptr); + EDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; @@ -784,7 +871,7 @@ struct GemmKernel __shared__ char smem_ptr_1[GetSmemSize()]; if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { RunGemm2LDS(as_ptr, bs_ptr, @@ -801,7 +888,7 @@ struct GemmKernel { if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + is_any_of::value)) { RunGemm(as_ptr, bs_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); }