From d9025d054d3dadeacd079faf38720165ca1389e5 Mon Sep 17 00:00:00 2001
From: aledudek
Date: Wed, 18 Dec 2024 17:52:46 +0100
Subject: [PATCH] [CK TILE] Refactor GemmKernel to be reused by other GEMM
 related operators (#1730)

* Gemm Kernel Refactor part1
* Gemm Kernel Refactor common gemm pipeline part2
* [CK TILE] Refactor batched gemm to reuse GemmKernel
* [CK TILE] Refactor GemmKernel - review changes part1
* [CK TILE] Refactor GemmKernel - references fix
* [CK TILE] Refactor GemmKernel - naming changes, add problem
* [CK_TILE] Refactor GemmKernel - update tests
* [CK_TILE] Refactor GemmKernel - review changes
* [CK_TILE] Refactor GemmKernel - update test
* [CK_TILE] Refactor GemmKernel - constness fixes
* [CK_TILE] Refactor GemmKernel - update tests

[ROCm/composable_kernel commit: 453ca373479e1c3510bff66c03a773a29f1caada]
---
 example/ck_tile/03_gemm/gemm_basic.cpp        |  14 +-
 example/ck_tile/03_gemm/gemm_basic.hpp        |  16 +-
 example/ck_tile/03_gemm/run_gemm_example.inc  |  10 +-
 .../ck_tile/16_batched_gemm/batched_gemm.cpp  |   6 +-
 .../ck_tile/16_batched_gemm/batched_gemm.hpp  |   6 +-
 .../run_batched_gemm_example.inc              |   2 +-
 .../ops/gemm/kernel/batched_gemm_kernel.hpp   | 274 +++++-------------
 .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp   | 259 ++++++++++++-----
 .../batched_gemm/test_batched_gemm_util.hpp   |  40 ++-
 test/ck_tile/gemm/test_gemm_pipeline_util.hpp |  38 +--
 10 files changed, 297 insertions(+), 368 deletions(-)

diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp
index f5260c306e..4c630375f4 100644
--- a/example/ck_tile/03_gemm/gemm_basic.cpp
+++ b/example/ck_tile/03_gemm/gemm_basic.cpp
@@ -15,7 +15,7 @@
 #include "gemm_basic.hpp"
 
 template 
-float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 {
     // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
     constexpr bool kPadM = false;
@@ -79,17 +79,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::GemmKernel;
 
-    auto kargs = Kernel::MakeKargs(args.p_a,
-                                   args.p_b,
-                                   args.p_c,
-                                   args.M,
-                                   args.N,
-                                   args.K,
-                                   args.stride_A,
-                                   args.stride_B,
-                                   args.stride_C);
+    auto kargs = Kernel::MakeKernelArgs(args);
 
-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
     constexpr dim3 blocks = Kernel::BlockSize();
 
     if(!Kernel::IsSupportedArgument(kargs))
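The net effect of the hunks above on the example's host code: the local gemm_basic_args struct and the nine-argument MakeKargs call are replaced by the library-provided ck_tile::GemmHostArgs and a single MakeKernelArgs(args) call. A minimal sketch of the resulting host-side flow, assuming the Kernel alias, device pointers, sizes, and strides that the example already has in scope:

    // New host-side sequence after this patch (sketch, not the full example).
    ck_tile::GemmHostArgs args;
    args.a_ptr   = p_a;     // const void*, device pointer to A
    args.b_ptr   = p_b;     // const void*, device pointer to B
    args.c_ptr   = p_c;     // void*, device pointer to C
    args.k_batch = kbatch;  // the example's existing k-batch option
    args.M = M; args.N = N; args.K = K;
    args.stride_A = stride_A; args.stride_B = stride_B; args.stride_C = stride_C;

    auto kargs            = Kernel::MakeKernelArgs(args); // host args -> kernel args
    const dim3 grids      = Kernel::GridSize(args.M, args.N, args.k_batch);
    constexpr dim3 blocks = Kernel::BlockSize();
    if(!Kernel::IsSupportedArgument(kargs))
    {
        // report and skip, as the example already does
    }
    // ...then launch Kernel{} with (grids, blocks, kargs) via the example's launcher.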
diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp
index 23e99bc2a8..58cdaea7d8 100644
--- a/example/ck_tile/03_gemm/gemm_basic.hpp
+++ b/example/ck_tile/03_gemm/gemm_basic.hpp
@@ -51,20 +51,6 @@ using BDataType = Types::BDataType;
 using AccDataType = Types::AccDataType;
 using CDataType = Types::CDataType;
 
-struct gemm_basic_args
-{
-    const void* p_a;
-    const void* p_b;
-    void* p_c;
-    ck_tile::index_t kbatch;
-    ck_tile::index_t M;
-    ck_tile::index_t N;
-    ck_tile::index_t K;
-    ck_tile::index_t stride_A;
-    ck_tile::index_t stride_B;
-    ck_tile::index_t stride_C;
-};
-
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
@@ -89,4 +75,4 @@ auto create_args(int argc, char* argv[])
 }
 
 // host API
-float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s);
+float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc
index 2b7a967bab..68df389bfc 100644
--- a/example/ck_tile/03_gemm/run_gemm_example.inc
+++ b/example/ck_tile/03_gemm/run_gemm_example.inc
@@ -16,11 +16,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                   int n_warmup,
                   int n_repeat)
 {
-    gemm_basic_args args;
-    args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
-    args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
-    args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
-    args.kbatch = kbatch;
+    ck_tile::GemmHostArgs args;
+    args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
+    args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
+    args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+    args.k_batch = kbatch;
     args.M = M;
     args.N = N;
     args.K = K;
diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp
index bfdd74126e..9b4ed9a9e7 100644
--- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp
+++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp
@@ -16,7 +16,7 @@
 #include "batched_gemm.hpp"
 
 template 
-float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
+float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
 {
     // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
     constexpr bool kPadM = false;
@@ -79,9 +79,9 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config&
     // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
     using Kernel = ck_tile::BatchedGemmKernel;
 
-    auto kargs = Kernel::MakeKargs(args);
+    auto kargs = Kernel::MakeKernelArgs(args);
 
-    const dim3 grids = Kernel::GridSize(args);
+    const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
     constexpr dim3 blocks = Kernel::BlockSize();
 
     if(s.log_level_ > 0)

diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp
index e252c0f673..f0c0c9efba 100644
--- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp
+++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp
@@ -29,10 +29,6 @@ using BDataType = Types::BDataType;
 using AccDataType = Types::AccDataType;
 using CDataType = Types::CDataType;
 
-struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
-{
-};
-
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
@@ -60,4 +56,4 @@ auto create_args(int argc, char* argv[])
 }
 
 // host API
-float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s);
+float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s);

diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
index 8345eef95b..4e7218b5b1 100644
--- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           int n_warmup,
                           int n_repeat)
 {
-    batched_gemm_kargs args;
+    ck_tile::BatchedGemmHostArgs args;
     args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
     args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
     args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
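The batched example now builds ck_tile::BatchedGemmHostArgs directly, since its thin local wrapper batched_gemm_kargs is gone. The struct can be filled field by field as above, or through the new constructor added in the kernel header below; note that in the constructor form the k-batch count comes before the sizes, mirroring the GemmHostArgs base. A sketch with placeholder pointer, stride, and count variables:

    // Constructor form; the argument order follows the patch below.
    ck_tile::BatchedGemmHostArgs args(p_a, p_b, p_c,
                                      /*k_batch_*/ 1,
                                      M, N, K,
                                      stride_A, stride_B, stride_C,
                                      batch_stride_A, batch_stride_B, batch_stride_C,
                                      batch_count);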
diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
index 07b4af5730..07a4cf8fbe 100644
--- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
@@ -3,90 +3,93 @@
 
 #pragma once
 
-#include 
-#include 
-
-#include "ck_tile/core.hpp"
-#include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 
 namespace ck_tile {
 
-struct BatchedGemmHostArgs
+struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs
 {
-    const void* a_ptr;
-    const void* b_ptr;
-    void* c_ptr;
-    index_t M;
-    index_t N;
-    index_t K;
-    index_t stride_A;
-    index_t stride_B;
-    index_t stride_C;
-    index_t batch_stride_A;
-    index_t batch_stride_B;
-    index_t batch_stride_C;
-    index_t batch_count;
+    CK_TILE_HOST BatchedGemmHostArgs() = default;
+    CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_,
+                                     const void* b_ptr_,
+                                     void* c_ptr_,
+                                     ck_tile::index_t k_batch_,
+                                     ck_tile::index_t M_,
+                                     ck_tile::index_t N_,
+                                     ck_tile::index_t K_,
+                                     ck_tile::index_t stride_A_,
+                                     ck_tile::index_t stride_B_,
+                                     ck_tile::index_t stride_C_,
+                                     ck_tile::index_t batch_stride_A_,
+                                     ck_tile::index_t batch_stride_B_,
+                                     ck_tile::index_t batch_stride_C_,
+                                     ck_tile::index_t batch_count_)
+        : GemmHostArgs(
+              a_ptr_, b_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_C_),
+          batch_stride_A(batch_stride_A_),
+          batch_stride_B(batch_stride_B_),
+          batch_stride_C(batch_stride_C_),
+          batch_count(batch_count_)
+    {
+    }
+
+    ck_tile::index_t batch_stride_A;
+    ck_tile::index_t batch_stride_B;
+    ck_tile::index_t batch_stride_C;
+    ck_tile::index_t batch_count;
 };
 
 template 
-struct BatchedGemmKernel
+struct BatchedGemmKernel : public GemmKernel
 {
-    using TilePartitioner = remove_cvref_t;
-    using GemmPipeline = remove_cvref_t;
-    using EpiloguePipeline = remove_cvref_t;
-    using ALayout = remove_cvref_t;
-    using BLayout = remove_cvref_t;
-    using CLayout = remove_cvref_t;
-    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
+    using Base = GemmKernel;
 
-    using ADataType = remove_cvref_t;
-    using BDataType = remove_cvref_t;
-    using CDataType = remove_cvref_t;
+    using GemmKernelArgs = typename Base::GemmKernelArgs;
 
-    struct BatchedGemmKargs
+    using ADataType = typename Base::ADataType;
+    using BDataType = typename Base::BDataType;
+    using CDataType = typename Base::CDataType;
+
+    using TilePartitioner = typename Base::TilePartitioner;
+    using GemmPipeline = typename Base::GemmPipeline;
+    using EpiloguePipeline = typename Base::EpiloguePipeline;
+    using ALayout = typename Base::ALayout;
+    using BLayout = typename Base::BLayout;
+    using CLayout = typename Base::CLayout;
+
+    struct BatchedGemmKernelArgs : GemmKernelArgs
     {
-        const void* a_ptr;
-        const void* b_ptr;
-        void* c_ptr;
-        index_t M;
-        index_t N;
-        index_t K;
-        index_t stride_A;
-        index_t stride_B;
-        index_t stride_C;
         index_t batch_stride_A;
         index_t batch_stride_B;
         index_t batch_stride_C;
         index_t batch_count;
     };
 
-    using Kargs = BatchedGemmKargs;
-    using Hargs = BatchedGemmHostArgs;
+    using KernelArgs = BatchedGemmKernelArgs;
 
-    __host__ static constexpr auto GridSize(const Hargs& h)
+    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
     {
-        return TilePartitioner::GridSize(h.M, h.N, h.batch_count);
+        return TilePartitioner::GridSize(M, N, batch_count);
     }
 
-    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+    __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
 
-    CK_TILE_HOST static constexpr BatchedGemmKargs MakeKargs(const Hargs& h)
+    CK_TILE_HOST static constexpr BatchedGemmKernelArgs
+    MakeKernelArgs(const BatchedGemmHostArgs& hostArgs)
     {
-        Kargs k;
-        k.a_ptr = h.a_ptr;
-        k.b_ptr = h.b_ptr;
-        k.c_ptr = h.c_ptr;
-        k.M = h.M;
-        k.N = h.N;
-        k.K = h.K;
-        k.stride_A = h.stride_A;
-        k.stride_B = h.stride_B;
-        k.stride_C = h.stride_C;
-        k.batch_stride_A = h.batch_stride_A;
-        k.batch_stride_B = h.batch_stride_B;
-        k.batch_stride_C = h.batch_stride_C;
-        k.batch_count = h.batch_count;
-        return k;
+        return BatchedGemmKernelArgs{{hostArgs.a_ptr,
+                                      hostArgs.b_ptr,
+                                      hostArgs.c_ptr,
+                                      hostArgs.M,
+                                      hostArgs.N,
+                                      hostArgs.K,
+                                      hostArgs.stride_A,
+                                      hostArgs.stride_B,
+                                      hostArgs.stride_C},
+                                     hostArgs.batch_stride_A,
+                                     hostArgs.batch_stride_B,
+                                     hostArgs.batch_stride_C,
+                                     hostArgs.batch_count};
     }
 
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
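One detail worth calling out in MakeKernelArgs above: because BatchedGemmKernelArgs now derives from GemmKernelArgs, the aggregate initializer uses a nested brace list, with the inner braces filling the GemmKernelArgs base subobject and the trailing values filling the batch-specific members (C++17 aggregate initialization with base classes). Any future kernel that extends GemmKernelArgs the same way would follow the same shape; a sketch with a hypothetical derived struct:

    // Hypothetical derived args struct, initialized like BatchedGemmKernelArgs above.
    struct MyKernelArgs : GemmKernelArgs
    {
        index_t extra; // invented member for illustration
    };
    MyKernelArgs k{{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C}, extra};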
@@ -94,7 +97,7 @@ struct BatchedGemmKernel
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }
 
-    CK_TILE_DEVICE void operator()(Kargs kargs) const
+    CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
         const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
@@ -102,156 +105,17 @@ struct BatchedGemmKernel
         // options
         const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
         const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
-        const ADataType* a_start = static_cast(kargs.a_ptr);
+        const ADataType* a_ptr = static_cast(kargs.a_ptr) + batch_offset_A;
 
         const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
         const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
-        const BDataType* b_start = static_cast(kargs.b_ptr);
-
-        // Convert pointers to tensor views
-        auto a_tensor_view = [&]() {
-            if constexpr(std::is_same_v)
-            {
-                return make_naive_tensor_view(
-                    a_start + batch_offset_A,
-                    make_tuple(kargs.M, kargs.K),
-                    make_tuple(kargs.stride_A, 1),
-                    number{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view(
-                    a_start + batch_offset_A,
-                    make_tuple(kargs.M, kargs.K),
-                    make_tuple(1, kargs.stride_A),
-                    number<1>{},
-                    number<1>{});
-            }
-        }();
-
-        auto b_tensor_view = [&]() {
-            if constexpr(std::is_same_v)
-            {
-                return make_naive_tensor_view(
-                    b_start + batch_offset_B,
-                    make_tuple(kargs.N, kargs.K),
-                    make_tuple(1, kargs.stride_B),
-                    number<1>{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view(
-                    b_start + batch_offset_B,
-                    make_tuple(kargs.N, kargs.K),
-                    make_tuple(kargs.stride_B, 1),
-                    number{},
-                    number<1>{});
-            }
-        }();
-
-        auto a_pad_view = [&]() {
-            if constexpr(std::is_same_v)
-            {
-                return pad_tensor_view(
-                    a_tensor_view,
-                    make_tuple(number{}, number{}),
-                    sequence{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    a_tensor_view,
-                    make_tuple(number{}, number{}),
-                    sequence{});
-            }
-        }();
-        // clang-format on
-
-        auto a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number{}, number{}),
-            {i_m, 0});
-
-        auto b_pad_view = [&]() {
-            if constexpr(std::is_same_v)
-            {
-                return pad_tensor_view(
-                    b_tensor_view,
-                    make_tuple(number{}, number{}),
-                    sequence{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    b_tensor_view,
-                    make_tuple(number{}, number{}),
-                    sequence{});
-            }
-        }();
-        // clang-format on
-
-        auto b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number{}, number{}),
-            {i_n, 0});
-
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
-
-        // Run GEMM cooperatively by whole wokrgroup.
-        auto c_block_tile =
-            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
+        const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B;
 
         const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
         const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
-        CDataType* c_start = static_cast(kargs.c_ptr);
-        auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v)
-            {
-                return make_naive_tensor_view(
-                    c_start + batch_offset_C,
-                    make_tuple(kargs.M, kargs.N),
-                    make_tuple(kargs.stride_C, 1),
-                    number{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view(
-                    c_start + batch_offset_C,
-                    make_tuple(kargs.M, kargs.N),
-                    make_tuple(1, kargs.stride_C),
-                    number<1>{},
-                    number<1>{});
-            }
-        }();
+        CDataType* c_ptr = static_cast(kargs.c_ptr) + batch_offset_C;
 
-        auto c_pad_view = [&]() {
-            if constexpr(std::is_same_v)
-            {
-                return pad_tensor_view(
-                    c_tensor_view,
-                    make_tuple(number{}, number{}),
-                    sequence{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    c_tensor_view,
-                    make_tuple(number{}, number{}),
-                    sequence{});
-            }
-        }();
-        auto c_block_window = make_tile_window(
-            c_pad_view,
-            make_tuple(number{}, number{}),
-            {i_m, i_n});
-
-        EpiloguePipeline{}(c_block_window, c_block_tile);
+        this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
     }
 };
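After this refactor the batched kernel body is reduced to batch bookkeeping: GridSize(M, N, batch_count) reserves the grid's z dimension for batches, each workgroup reads its batch index from blockIdx.z, offsets the three base pointers by the batch strides, and delegates everything else to the inherited RunGemm (defined in gemm_kernel.hpp below). Stripped of the readfirstlane intrinsics, the flow is roughly:

    // Simplified sketch of the new BatchedGemmKernel::operator() logic.
    const auto [i_m, i_n] = TilePartitioner{}();
    const index_t i_batch  = blockIdx.z; // one GEMM slice per z index
    const ADataType* a_ptr = base_a + i_batch * kargs.batch_stride_A;
    const BDataType* b_ptr = base_b + i_batch * kargs.batch_stride_B;
    CDataType*       c_ptr = base_c + i_batch * kargs.batch_stride_C;
    this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n); // shared GemmKernel path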
diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
index 763d8cad9c..925648a886 100644
--- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
@@ -12,6 +12,50 @@
 
 namespace ck_tile {
 
+struct GemmProblem
+{
+    CK_TILE_HOST GemmProblem() = default;
+    CK_TILE_HOST GemmProblem(
+        index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
+        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
+    {
+    }
+
+    index_t M;
+    index_t N;
+    index_t K;
+    index_t stride_A;
+    index_t stride_B;
+    index_t stride_C;
+};
+
+struct GemmHostArgs : public GemmProblem
+{
+    CK_TILE_HOST GemmHostArgs() = default;
+    CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
+                              const void* b_ptr_,
+                              void* c_ptr_,
+                              index_t k_batch_,
+                              index_t M_,
+                              index_t N_,
+                              index_t K_,
+                              index_t stride_A_,
+                              index_t stride_B_,
+                              index_t stride_C_)
+        : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
+          a_ptr(a_ptr_),
+          b_ptr(b_ptr_),
+          c_ptr(c_ptr_),
+          k_batch(k_batch_)
+    {
+    }
+
+    const void* a_ptr;
+    const void* b_ptr;
+    void* c_ptr;
+    index_t k_batch;
+};
+
 template 
 struct GemmKernel
 {
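The new argument hierarchy introduced above is GemmProblem (sizes and strides), extended by GemmHostArgs (adds the three pointers and k_batch), extended in turn by BatchedGemmHostArgs in the previous file (adds batch strides and count). Since the inheritance is public, helpers written against the base accept any of them; a sketch with a hypothetical validation helper (the checks are illustrative only):

    // Hypothetical: one problem-level check usable with plain and batched host args.
    // Assumes the ck_tile gemm kernel header from this patch is included.
    bool is_valid_problem(const ck_tile::GemmProblem& p)
    {
        return p.M > 0 && p.N > 0 && p.K > 0;
    }
    // is_valid_problem(gemm_host_args);    // GemmHostArgs is-a GemmProblem
    // is_valid_problem(batched_host_args); // so is BatchedGemmHostArgs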
@@ -25,9 +69,12 @@ struct GemmKernel
 
     using ADataType = remove_cvref_t;
     using BDataType = remove_cvref_t;
-    // using CAccDataType = remove_cvref_t;
     using CDataType = remove_cvref_t;
 
+    static constexpr auto I0 = number<0>();
+    static constexpr auto I1 = number<1>();
+    static constexpr auto I2 = number<2>();
+
     __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
         return TilePartitioner::GridSize(M, N, KBatch);
@@ -35,7 +82,7 @@ struct GemmKernel
 
     __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
 
-    struct GemmCommonKargs
+    struct GemmKernelArgs
     {
         const void* a_ptr;
         const void* b_ptr;
@@ -48,25 +95,37 @@ struct GemmKernel
         index_t stride_C;
     };
 
-    CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
-                                                            const void* b_ptr,
-                                                            void* c_ptr,
-                                                            index_t M,
-                                                            index_t N,
-                                                            index_t K,
-                                                            index_t stride_A,
-                                                            index_t stride_B,
-                                                            index_t stride_C)
+    CK_TILE_HOST static constexpr GemmKernelArgs
+    MakeKernelArgs(const GemmHostArgs& hostArgs)
     {
-        return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
+        return GemmKernelArgs{hostArgs.a_ptr,
+                              hostArgs.b_ptr,
+                              hostArgs.c_ptr,
+                              hostArgs.M,
+                              hostArgs.N,
+                              hostArgs.K,
+                              hostArgs.stride_A,
+                              hostArgs.stride_B,
+                              hostArgs.stride_C};
     }
+    // CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
+    //                                                             const void* b_ptr,
+    //                                                             void* c_ptr,
+    //                                                             index_t M,
+    //                                                             index_t N,
+    //                                                             index_t K,
+    //                                                             index_t stride_A,
+    //                                                             index_t stride_B,
+    //                                                             index_t stride_C)
+    // {
+    //     return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
+    // }
 
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }
 
-    CK_TILE_HOST static bool IsSupportedArgument(const GemmCommonKargs& kargs)
+    CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
     {
         if constexpr(std::is_same_v)
         {
@@ -139,18 +198,16 @@ struct GemmKernel
         return true;
     }
 
-    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
+    CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
+                                            const BDataType* b_ptr,
+                                            CDataType* c_ptr,
+                                            const GemmKernelArgs& kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
-        // options
-        const ADataType* a_start = static_cast(kargs.a_ptr);
-        const BDataType* b_start = static_cast(kargs.b_ptr);
-        // Convert pointers to tensor views
-        auto a_tensor_view = [&]() {
+        const auto& a_tensor_view = [&]() {
             if constexpr(std::is_same_v)
             {
                 return make_naive_tensor_view(
-                    a_start,
+                    a_ptr,
                     make_tuple(kargs.M, kargs.K),
                     make_tuple(kargs.stride_A, 1),
                     number{},
@@ -159,7 +216,7 @@ struct GemmKernel
             else
             {
                 return make_naive_tensor_view(
-                    a_start,
+                    a_ptr,
                     make_tuple(kargs.M, kargs.K),
                     make_tuple(1, kargs.stride_A),
                     number<1>{},
@@ -167,11 +224,11 @@ struct GemmKernel
             }
         }();
 
-        auto b_tensor_view = [&]() {
+        const auto& b_tensor_view = [&]() {
             if constexpr(std::is_same_v)
             {
                 return make_naive_tensor_view(
-                    b_start,
+                    b_ptr,
                     make_tuple(kargs.N, kargs.K),
                     make_tuple(1, kargs.stride_B),
                     number<1>{},
@@ -180,7 +237,7 @@ struct GemmKernel
             else
             {
                 return make_naive_tensor_view(
-                    b_start,
+                    b_ptr,
                     make_tuple(kargs.N, kargs.K),
                     make_tuple(kargs.stride_B, 1),
                     number{},
@@ -188,7 +245,35 @@ struct GemmKernel
             }
         }();
 
-        auto a_pad_view = [&]() {
+        const auto& c_tensor_view = [&]() {
+            if constexpr(std::is_same_v)
+            {
+                return make_naive_tensor_view(
+                    c_ptr,
+                    make_tuple(kargs.M, kargs.N),
+                    make_tuple(kargs.stride_C, 1),
+                    number{},
+                    number<1>{});
+            }
+            else
+            {
+                return make_naive_tensor_view(
+                    c_ptr,
+                    make_tuple(kargs.M, kargs.N),
+                    make_tuple(1, kargs.stride_C),
+                    number<1>{},
+                    number<1>{});
+            }
+        }();
+
+        return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
+    }
+
+    template 
+    CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
+    {
+        const auto& a_pad_view = [&]() {
+            const auto& a_tensor_view = views.at(I0);
             if constexpr(std::is_same_v)
             {
                 return pad_tensor_view(
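The refactored helpers communicate through positional tuples: MakeGemmTensorViews returns the (A, B, C) views, and the later stages pick elements with the I0/I1/I2 constants added earlier in this file. A short sketch of how the three stages chain inside RunGemm (all names from this patch; the padding follows the pipeline's configured pad flags):

    // How the staged helpers compose inside RunGemm (see further below).
    const auto& views  = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
    const auto& padded = MakeGemmPadViews(views);
    auto windows       = MakeGemmTileWindows(padded, i_m, i_n);
    const auto& a_win  = windows.at(I0); // A tile window
    const auto& b_win  = windows.at(I1); // B tile window
    auto&       c_win  = windows.at(I2); // C tile window, written by the epilogue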
@@ -204,14 +289,9 @@ struct GemmKernel
                     sequence{});
             }
         }();
-        // clang-format on
 
-        auto a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number{}, number{}),
-            {i_m, 0});
-
-        auto b_pad_view = [&]() {
+        const auto& b_pad_view = [&]() {
+            const auto& b_tensor_view = views.at(I1);
             if constexpr(std::is_same_v)
             {
                 return pad_tensor_view(
@@ -228,43 +308,8 @@ struct GemmKernel
             }
         }();
 
-        auto b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number{}, number{}),
-            {i_n, 0});
-
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
-
-        // Run GEMM cooperatively by whole wokrgroup.
-        auto c_block_tile =
-            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
-
-        CDataType* c_start = static_cast(kargs.c_ptr);
-        auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v)
-            {
-                return make_naive_tensor_view(
-                    c_start,
-                    make_tuple(kargs.M, kargs.N),
-                    make_tuple(kargs.stride_C, 1),
-                    number{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view(
-                    c_start,
-                    make_tuple(kargs.M, kargs.N),
-                    make_tuple(1, kargs.stride_C),
-                    number<1>{},
-                    number<1>{});
-            }
-        }();
-
-        auto c_pad_view = [&]() {
+        const auto& c_pad_view = [&]() {
+            const auto& c_tensor_view = views.at(I2);
             if constexpr(std::is_same_v)
             {
                 return pad_tensor_view(
@@ -280,12 +325,82 @@ struct GemmKernel
                     sequence{});
             }
         }();
-        auto CBlockWindow_pad = make_tile_window(
+
+        return make_tuple(a_pad_view, b_pad_view, c_pad_view);
+    }
+
+    template 
+    CK_TILE_DEVICE auto
+    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
+    {
+        const auto& a_pad_view = views.at(I0);
+        const auto& a_block_window = make_tile_window(
+            a_pad_view,
+            make_tuple(number{}, number{}),
+            {i_m, 0});
+
+        const auto& b_pad_view = views.at(I1);
+        const auto& b_block_window = make_tile_window(
+            b_pad_view,
+            make_tuple(number{}, number{}),
+            {i_n, 0});
+
+        const auto& c_pad_view = views.at(I2);
+        auto c_block_window = make_tile_window(
             c_pad_view,
             make_tuple(number{}, number{}),
             {i_m, i_n});
 
-        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
+        return make_tuple(a_block_window, b_block_window, c_block_window);
+    }
+
+    /**
+     * @brief Runs single GEMM problem cooperatively by whole workgroup.
+     *
+     * @param a_ptr input A pointer
+     * @param b_ptr input B pointer
+     * @param c_ptr output C pointer
+     * @param kargs GEMM kernel arguments
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     */
+    CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
+                                const BDataType* b_ptr,
+                                CDataType* c_ptr,
+                                const GemmKernelArgs& kargs,
+                                const index_t block_idx_m,
+                                const index_t block_idx_n) const
+    {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& a_block_window = gemm_tile_windows.at(I0);
+        const auto& b_block_window = gemm_tile_windows.at(I1);
+        const auto& c_block_tile =
+            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
+
+        // Run Epilogue Pipeline
+        auto& c_block_window = gemm_tile_windows.at(I2);
+        EpiloguePipeline{}(c_block_window, c_block_tile);
+    }
+
+    CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
+    {
+        const auto [i_m, i_n] = TilePartitioner{}();
+        // options
+        const ADataType* a_ptr = static_cast(kargs.a_ptr);
+        const BDataType* b_ptr = static_cast(kargs.b_ptr);
+        CDataType* c_ptr = static_cast(kargs.c_ptr);
+
+        RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
     }
 };
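With RunGemm, the view/pad/window helpers, and the GemmKernelArgs/GemmHostArgs pair all living in GemmKernel, a new GEMM-flavoured operator only has to provide its own argument plumbing and an operator() that computes per-workgroup base pointers, exactly as BatchedGemmKernel does earlier in this patch. A hypothetical sketch of such a derived kernel; the names are invented for illustration, and the base's template parameter list is written out on the assumption that it takes the partitioner and the two pipelines, as the usings in this patch suggest:

    // Hypothetical operator reusing the refactored base, BatchedGemmKernel-style.
    template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
    struct MyGemmLikeKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
    {
        using Base       = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
        using ADataType  = typename Base::ADataType;
        using BDataType  = typename Base::BDataType;
        using CDataType  = typename Base::CDataType;
        using KernelArgs = typename Base::GemmKernelArgs;

        CK_TILE_DEVICE void operator()(KernelArgs kargs) const
        {
            const auto [i_m, i_n] = typename Base::TilePartitioner{}();
            // ...derive this operator's own A/B/C base pointers from kargs here...
            const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
            const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
            CDataType*       c_ptr = static_cast<CDataType*>(kargs.c_ptr);
            this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n); // shared device path
        }
    };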
diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
index 88145b987b..d3f3077870 100644
--- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
@@ -24,12 +24,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
     using AccDataType = std::tuple_element_t<5, Tuple>;
     using CDataType = std::tuple_element_t<6, Tuple>;
 
-    struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
-    {
-    };
-
     template 
-    void invoke_batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
+    void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
+                             const ck_tile::stream_config& s)
     {
         // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
         constexpr bool kPadM = false;
@@ -94,9 +91,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
         using Kernel = ck_tile::BatchedGemmKernel;
 
-        auto kargs = Kernel::MakeKargs(args);
+        auto kargs = Kernel::MakeKernelArgs(args);
 
-        const dim3 grids = Kernel::GridSize(args);
+        const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
         constexpr dim3 blocks = Kernel::BlockSize();
 
         if(s.log_level_ > 0)
@@ -185,21 +182,22 @@ class TestCkTileBatchedGemm : public ::testing::Test
         c_m_n_dev_buf.SetZero();
         c_m_n_dev_result.SetZero();
 
-        batched_gemm_kargs kargs{a_m_k_dev_buf.GetDeviceBuffer(),
-                                 b_k_n_dev_buf.GetDeviceBuffer(),
-                                 c_m_n_dev_buf.GetDeviceBuffer(),
-                                 M,
-                                 N,
-                                 K,
-                                 StrideA,
-                                 StrideB,
-                                 StrideC,
-                                 BatchStrideA,
-                                 BatchStrideB,
-                                 BatchStrideC,
-                                 BatchCount};
+        ck_tile::BatchedGemmHostArgs args;
+        args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
+        args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
+        args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+        args.M = M;
+        args.N = N;
+        args.K = K;
+        args.stride_A = StrideA;
+        args.stride_B = StrideB;
+        args.stride_C = StrideC;
+        args.batch_stride_A = BatchStrideA;
+        args.batch_stride_B = BatchStrideB;
+        args.batch_stride_C = BatchStrideC;
+        args.batch_count = BatchCount;
 
-        invoke_batched_gemm(kargs,
+        invoke_batched_gemm(args,
                             ck_tile::stream_config{nullptr, false});
 
         std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K

diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
index a514986024..53ead4d8d6 100644
--- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
     static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
     // TODO: expose tile size through test t-param ?
 
-    struct gemm_args
-    {
-        const void* p_a;
-        const void* p_b;
-        void* p_c;
-        ck_tile::index_t kbatch;
-        ck_tile::index_t M;
-        ck_tile::index_t N;
-        ck_tile::index_t K;
-        ck_tile::index_t stride_A;
-        ck_tile::index_t stride_B;
-        ck_tile::index_t stride_C;
-    };
-
     template 
-    void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s)
+    void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
     {
         // TODO: This should be parameterized in tests
         constexpr ck_tile::index_t M_Tile = 128;
@@ -117,17 +103,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
                                                 has_hot_loop_v, tail_number_v>>>;
         using Kernel = ck_tile::GemmKernel;
 
-        auto kargs = Kernel::MakeKargs(args.p_a,
-                                       args.p_b,
-                                       args.p_c,
-                                       args.M,
-                                       args.N,
-                                       args.K,
-                                       args.stride_A,
-                                       args.stride_B,
-                                       args.stride_C);
+        auto kargs = Kernel::MakeKernelArgs(args);
 
-        const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
+        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
         constexpr dim3 blocks = Kernel::BlockSize();
 
         if(!Kernel::IsSupportedArgument(kargs))
@@ -319,11 +297,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
         c_m_n_dev_buf.SetZero();
         c_m_n_dev_result.SetZero();
 
-        gemm_args args;
-        args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
-        args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
-        args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
-        args.kbatch = kbatch;
+        ck_tile::GemmHostArgs args;
+        args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
+        args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
+        args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
+        args.k_batch = kbatch;
         args.M = M;
         args.N = N;
         args.K = K;