From 8ce06348b5bb8f0606ef8d332b301c9527d3d562 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Sat, 22 Mar 2025 18:39:35 +0000 Subject: [PATCH] Simple HostArgs struct --- example/ck_tile/03_gemm/gemm_basic.cpp | 2 +- example/ck_tile/03_gemm/gemm_utils.hpp | 3 +-- example/ck_tile/03_gemm/run_gemm_example.inc | 24 ++++++++--------- example/ck_tile/03_gemm/universal_gemm.cpp | 2 +- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 2 +- .../ck_tile/18_multi_d_gemm/multi_d_gemm.hpp | 2 +- .../run_multi_d_gemm_example.inc | 8 +++--- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 2 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 16 +++++------ .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 10 +++---- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 5 ++-- .../grouped_gemm/test_grouped_gemm_util.hpp | 2 +- .../test_multiple_d_gemm_util.hpp | 27 +++++++++---------- 13 files changed, 49 insertions(+), 56 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index e520984605..4294fed082 100755 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -22,7 +22,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 1170b56aa0..3254a407fd 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -219,5 +219,4 @@ auto create_args(int argc, char* argv[]) } // host API -float gemm_calc(const ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s); +float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 2d5daa30e4..35732dcc90 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -166,18 +166,18 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_warmup, int n_repeat) { - ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - {}, - c_m_n_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - {}, - stride_C}; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + {}, + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C}; float ave_time = gemm -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 605f43ace1..9674e69b33 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -54,7 +54,7 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; -using grouped_gemm_kargs = ck_tile::GemmHostArgs; +using grouped_gemm_kargs = ck_tile::GemmHostArgs; auto create_args(int argc, char* argv[]) { diff --git a/example/ck_tile/18_multi_d_gemm/multi_d_gemm.hpp b/example/ck_tile/18_multi_d_gemm/multi_d_gemm.hpp index 8a3aa981a8..f1afaabe05 100644 --- a/example/ck_tile/18_multi_d_gemm/multi_d_gemm.hpp +++ b/example/ck_tile/18_multi_d_gemm/multi_d_gemm.hpp @@ -62,6 +62,6 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -using multiple_d_gemm_kargs = ck_tile::GemmHostArgs; +using multiple_d_gemm_kargs = ck_tile::GemmHostArgs; float multiple_d_gemm(const multiple_d_gemm_kargs& kargs, const ck_tile::stream_config& s); diff --git a/example/ck_tile/18_multi_d_gemm/run_multi_d_gemm_example.inc b/example/ck_tile/18_multi_d_gemm/run_multi_d_gemm_example.inc index 0e61fe7c9d..4868f34ce6 100644 --- a/example/ck_tile/18_multi_d_gemm/run_multi_d_gemm_example.inc +++ b/example/ck_tile/18_multi_d_gemm/run_multi_d_gemm_example.inc @@ -16,21 +16,21 @@ template float invoke_multi_d_gemm(const void* a_m_k_dev_buf, const void* b_k_n_dev_buf, - const std::array& d_m_n_dev_buf, + const std::array& ds_m_n_dev_buf, void* c_m_n_dev_buf, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t StrideA, ck_tile::index_t StrideB, - const std::array StrideDs, + const std::array& StrideDs, ck_tile::index_t StrideC, int n_warmup, int n_repeat) { multiple_d_gemm_kargs gemm_descs({a_m_k_dev_buf, b_k_n_dev_buf, - d_m_n_dev_buf, + ds_m_n_dev_buf.data(), c_m_n_dev_buf, /*kbatch */ 1, M, @@ -38,7 +38,7 @@ float invoke_multi_d_gemm(const void* a_m_k_dev_buf, K, StrideA, StrideB, - StrideDs, + StrideDs.data(), StrideC}); float ave_time = multiple_d_gemm +struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs { CK_TILE_HOST BatchedGemmHostArgs() = default; CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_, diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 70b7da98b4..9b118783f6 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -12,13 +12,12 @@ namespace ck_tile { -template struct GemmHostArgs { CK_TILE_HOST GemmHostArgs() = default; CK_TILE_HOST GemmHostArgs(const void* a_ptr_, const void* b_ptr_, - const std::array& ds_ptr_, + const void* ds_ptr_, void* c_ptr_, index_t k_batch_, index_t M_, @@ -26,7 +25,7 @@ struct GemmHostArgs index_t K_, index_t stride_A_, index_t stride_B_, - const std::array& stride_Ds_, + const index_t* stride_Ds_, index_t stride_C_) : a_ptr(a_ptr_), b_ptr(b_ptr_), @@ -45,14 +44,14 @@ struct GemmHostArgs const void* a_ptr; const void* b_ptr; - const std::array ds_ptr; + const void* ds_ptr; void* c_ptr; index_t M; index_t N; index_t K; index_t stride_A; index_t stride_B; - const std::array stride_Ds; + const index_t* stride_Ds; index_t stride_C; index_t k_batch; }; @@ -126,19 +125,18 @@ struct GemmKernel CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - CK_TILE_HOST static constexpr GemmKernelArgs - MakeKernelArgs(const GemmHostArgs& hostArgs) + CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) { return GemmKernelArgs{hostArgs.a_ptr, hostArgs.b_ptr, - static_cast(hostArgs.ds_ptr.data()), + hostArgs.ds_ptr, // static_cast(hostArgs.ds_ptr.data()), hostArgs.c_ptr, hostArgs.M, hostArgs.N, hostArgs.K, hostArgs.stride_A, hostArgs.stride_B, - hostArgs.stride_Ds.data(), + hostArgs.stride_Ds, hostArgs.stride_C, hostArgs.k_batch}; } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index b7dc9bdabc..652e107d33 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -55,16 +55,15 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) -> std::size_t + __host__ static auto GetWorkSpaceSize(const std::vector& gemm_descs) + -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } __host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); } - __host__ static constexpr auto - GridSize(const std::vector>& gemm_descs) + __host__ static constexpr auto GridSize(const std::vector& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -75,8 +74,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) + CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index a1e4bdb574..8a49392bc3 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -82,8 +82,7 @@ class TestCkTileGemmPipeline : public ::testing::Test // TODO: expose tile size through test t-param ? template - void invoke_gemm(const ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) + void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 256; @@ -424,7 +423,7 @@ class TestCkTileGemmPipeline : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - ck_tile::GemmHostArgs args; + ck_tile::GemmHostArgs args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index ce1e48eb8e..d729522d06 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -47,7 +47,7 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 8; }; - using grouped_gemm_kargs = ck_tile::GemmHostArgs; + using grouped_gemm_kargs = ck_tile::GemmHostArgs; std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); diff --git a/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_util.hpp b/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_util.hpp index 17c25869a7..77d8d1bf81 100644 --- a/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_util.hpp +++ b/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_util.hpp @@ -64,8 +64,7 @@ class TestCkTileMultipleDGemm : public ::testing::Test typename DsLayout, typename CLayout, typename CDEElementWise = ck_tile::element_wise::PassThrough> - void invoke_multi_d_gemm(const ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) + void invoke_multi_d_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { constexpr ck_tile::index_t M_Tile = 256; constexpr ck_tile::index_t N_Tile = 256; @@ -291,18 +290,18 @@ class TestCkTileMultipleDGemm : public ::testing::Test d1_m_n_dev_buf.GetDeviceBuffer()}; std::array stridesDs = {StrideD0, StrideD1}; - ck_tile::GemmHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - ds_ptr_buf, - c_m_n_dev_buf.GetDeviceBuffer(), - /* kBatch */ 1, - M, - N, - K, - StrideA, - StrideB, - stridesDs, - StrideC}); + ck_tile::GemmHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + ds_ptr_buf.data(), + c_m_n_dev_buf.GetDeviceBuffer(), + /* kBatch */ 1, + M, + N, + K, + StrideA, + StrideB, + stridesDs.data(), + StrideC}); invoke_multi_d_gemm