From 0c0fd440caf54d6a8d00de03f612c59c3b62ea59 Mon Sep 17 00:00:00 2001
From: Mateusz Ozga <110818320+mozga-amd@users.noreply.github.com>
Date: Thu, 24 Jul 2025 20:39:56 +0200
Subject: [PATCH] [CK_TILE] Introduces a new GEMM API that splits the existing
 basic GEMM class into multiple specialized classes. (#2520)

* Init commit new API
* apply clang-format
* PreShuffle preparing
* Apply Preshuffle condition to universal_gemm
* Fix: convert size_t to index_t
* Review changes
* Mode 100755 -> 100644

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: b507d889c11b099004f94b1402d0693c3942234c]
---
 example/ck_tile/03_gemm/gemm_basic.cpp        |    2 +-
 example/ck_tile/03_gemm/gemm_utils.hpp        |    2 +-
 .../03_gemm/gemm_weight_preshuffle.cpp        |  209 +--
 example/ck_tile/03_gemm/run_gemm_example.inc  |   24 +-
 example/ck_tile/03_gemm/universal_gemm.cpp    |  209 +--
 .../run_batched_gemm_example.inc              |   29 +-
 .../ck_tile/17_grouped_gemm/grouped_gemm.hpp  |    2 +-
 .../run_grouped_gemm_example.inc              |   26 +-
 .../19_gemm_multi_d/gemm_multi_d_fp16.cpp     |    2 +-
 .../19_gemm_multi_d/gemm_multi_d_fp16.hpp     |    2 +-
 include/ck_tile/core/container/tuple.hpp      |    2 +
 include/ck_tile/ops/gemm.hpp                  |    2 +
 .../ops/gemm/kernel/batched_gemm_kernel.hpp   |  166 ++-
 .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp   | 1015 +-------------
 .../ops/gemm/kernel/gemm_multi_d_kernel.hpp   |  185 +++
 .../ops/gemm/kernel/grouped_gemm_kernel.hpp   |  165 ++-
 .../ops/gemm/kernel/universal_gemm_kernel.hpp | 1169 +++++++++++++++++
 .../batched_gemm/test_batched_gemm_util.hpp   |   29 +-
 .../test_gemm_pipeline_basic_run_test.inc     |    2 +-
 .../test_gemm_pipeline_smoke_run_test.inc     |   24 +-
 .../gemm/test_gemm_pipeline_smoke_util.hpp    |    2 +-
 .../test_gemm_pipeline_universal_run_test.inc |  211 +--
 test/ck_tile/gemm/test_gemm_pipeline_util.hpp |   30 +-
 .../gemm_multi_d/test_gemm_multi_d_util.hpp   |   30 +-
 .../test_gemm_pipeline_util.hpp               |   30 +-
 .../grouped_gemm/test_grouped_gemm_util.hpp   |   28 +-
 tile_engine/ops/gemm/gemm_instance_builder.py |   10 +-
 tile_engine/ops/gemm/gemm_profiler.hpp        |    6 +-
 28 files changed, 2094 insertions(+), 1519 deletions(-)
 create mode 100644 include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp
 create mode 100644 include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp

diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp
index 80c18cdb87..0d9c2d9957 100644
--- a/example/ck_tile/03_gemm/gemm_basic.cpp
+++ b/example/ck_tile/03_gemm/gemm_basic.cpp
@@ -24,7 +24,7 @@ template
-float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
+float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 {
     if constexpr(Persistent)
diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp
index 24f64994cf..1e867afd1a 100644
--- a/example/ck_tile/03_gemm/gemm_utils.hpp
+++ b/example/ck_tile/03_gemm/gemm_utils.hpp
@@ -475,4 +475,4 @@ template
-float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
+float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
index b7b0701080..34333d5474 100644
--- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
+++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
@@ -25,7 +25,7 @@ template
-float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
+float gemm(const ck_tile::GemmHostArgs& args, const
ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -74,119 +74,120 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 
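// ---------------------------------------------------------------------------
// Illustration: the driver above picks Kernel::MaxOccupancyGridSize(s) for
// persistent launches and Kernel::GridSize(M, N, k_batch) otherwise. A
// minimal standalone HIP sketch of occupancy-based grid sizing, following the
// pattern of the deleted GemmKernel::MaxOccupancyGridSize shown later in this
// patch; `dummy_kernel` and the block size are placeholders:
#include <hip/hip_runtime.h>

__global__ void dummy_kernel() {}

dim3 max_occupancy_grid_size(int block_size)
{
    int occupancy = 0;
    // HIP runtime query: how many blocks of block_size threads fit on one
    // compute unit (with 0 bytes of dynamic LDS).
    (void)hipOccupancyMaxActiveBlocksPerMultiprocessor(
        &occupancy, reinterpret_cast<const void*>(dummy_kernel), block_size, 0);
    hipDeviceProp_t prop{};
    (void)hipGetDeviceProperties(&prop, 0);
    // Launch exactly one full occupancy wave: blocks-per-CU times CU count.
    return dim3(prop.multiProcessorCount * occupancy, 1, 1);
}
// ---------------------------------------------------------------------------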
2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 83836117e9..7f87c2bc06 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -158,7 +158,7 @@ template -float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); template args = {a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - {}, - c_m_n_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - {}, - stride_C}; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; float ave_time; if(persistent) diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index c96a470910..6c60f98fa4 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -25,7 +25,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const 
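// ---------------------------------------------------------------------------
// Illustration: with the Ds placeholder fields gone, the basic example fills
// a flat ck_tile::GemmHostArgs as in the hunk above. A hedged usage sketch
// (the device pointers and strides are hypothetical; the argument order
// follows the GemmHostArgs constructor later in this patch):
//
//   ck_tile::GemmHostArgs args = {a_dev_ptr,  // const void* a_ptr
//                                 b_dev_ptr,  // const void* b_ptr
//                                 c_dev_ptr,  // void* e_ptr (aka c_ptr)
//                                 kbatch,     // index_t k_batch
//                                 M, N, K,
//                                 stride_A,
//                                 stride_B,
//                                 stride_C};  // index_t stride_E (aka stride_C)
// ---------------------------------------------------------------------------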
ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -74,120 +74,121 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." 
<< std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 7d5e1910dd..6d26cfe675 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -50,21 +50,20 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_warmup, int n_repeat) { - ck_tile::BatchedGemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_E = stride_C; - args.batch_stride_A = batch_stride_A; - args.batch_stride_B = batch_stride_B; - args.batch_stride_E = batch_stride_C; - args.batch_count = batch_count; + ck_tile::BatchedGemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + 
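// ---------------------------------------------------------------------------
// Illustration: the benchmark above now hands kargs.as_ptr[0] / kargs.bs_ptr[0]
// to RotatingMemWrapper because the universal kernel arguments store arrays
// of input pointers (one slot per A/B tensor) instead of single a_ptr/b_ptr
// fields. A standalone sketch of that generalization; the struct below is
// illustrative only, not ck_tile's actual definition:
#include <array>

template <int NumATensor = 1, int NumBTensor = 1>
struct universal_args_sketch
{
    std::array<const void*, NumATensor> as_ptr; // was: const void* a_ptr
    std::array<const void*, NumBTensor> bs_ptr; // was: const void* b_ptr
};

// A single-input GEMM keeps using slot 0, e.g. kargs.as_ptr[0].
// ---------------------------------------------------------------------------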
stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count}; float ave_time = batched_gemm; +using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; auto create_args(int argc, char* argv[]) { diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 5ed1219731..7532923f9a 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -83,18 +83,18 @@ float invoke_gemm(int n_warmup, const bool splitk = args[0].k_batch > 1; for(const auto& arg : args) { - kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr, - arg.b_ptr, - {}, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.stride_A, - arg.stride_B, - {}, - arg.stride_E, - arg.k_batch}); + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, + {arg.b_ptr}, + {/*arg.ds_ptr*/}, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + {/*arg.stride_Ds*/}, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, @@ -240,7 +240,7 @@ int run_grouped_gemm_example_with_layouts(int argc, void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); gemm_descs.push_back( - {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]}); + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } invoke_gemm>; - using Kernel = ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp index 3ce3965e56..87b9592553 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -64,7 +64,7 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -using gemm_multi_d_kargs = ck_tile::GemmHostArgs; +using gemm_multi_d_kargs = ck_tile::GemmMultiDHostArgs; template , T...> return flag; } + CK_TILE_HOST_DEVICE static constexpr bool IsTuple() { return true; } + #define TP_COM_() static_assert(I < size(), "wrong! 
out of range") // clang-format off template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv(*this); } diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index b396f03244..9d00de5f73 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -28,6 +28,8 @@ #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index fc72138abf..9c1ce73eac 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -9,35 +9,41 @@ namespace ck_tile { -struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs +/// @brief The Batched GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref BatchedGemmKernel "BatchedGemmKernel" when creating kernel +/// arguments object. It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. +struct BatchedGemmHostArgs : public ck_tile::UniversalGemmHostArgs<> { - CK_TILE_HOST BatchedGemmHostArgs() = default; - CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_, - const void* b_ptr_, - void* c_ptr_, - ck_tile::index_t k_batch_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_, - ck_tile::index_t batch_stride_A_, - ck_tile::index_t batch_stride_B_, - ck_tile::index_t batch_stride_C_, - ck_tile::index_t batch_count_) - : GemmHostArgs(a_ptr_, - b_ptr_, - {}, - c_ptr_, - k_batch_, - M_, - N_, - K_, - stride_A_, - stride_B_, - {}, - stride_C_), + CK_TILE_HOST explicit BatchedGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + ck_tile::index_t k_batch_, + ck_tile::index_t M_, + ck_tile::index_t N_, + ck_tile::index_t K_, + ck_tile::index_t stride_A_, + ck_tile::index_t stride_B_, + ck_tile::index_t stride_C_, + ck_tile::index_t batch_stride_A_, + ck_tile::index_t batch_stride_B_, + ck_tile::index_t batch_stride_C_, + ck_tile::index_t batch_count_) + : UniversalGemmHostArgs<>({a_ptr_}, + {b_ptr_}, + {/*ds_ptr*/}, + c_ptr_, + k_batch_, + M_, + N_, + K_, + {stride_A_}, + {stride_B_}, + {/*stride_Ds_*/}, + stride_C_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), batch_stride_E(batch_stride_C_), @@ -52,36 +58,43 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs }; template -struct BatchedGemmKernel : public GemmKernel +struct BatchedGemmKernel { - using Base = GemmKernel; + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. 
+ using UniversalGemmKernel = + UniversalGemmKernel; - using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; - using ADataType = typename Base::ADataType; - using BDataType = typename Base::BDataType; - using CDataType = typename Base::EDataType; + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; - using TilePartitioner = typename Base::TilePartitioner; - using GemmPipeline = typename Base::GemmPipeline; - using EpiloguePipeline = typename Base::EpiloguePipeline; - using ALayout = typename Base::ALayout; - using BLayout = typename Base::BLayout; - using CLayout = typename Base::ELayout; + /// @brief Specify the data type configurations for A, B, E and D + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - using P_ = GemmPipeline; + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); - return concat('_', "gemm_batched", gemm_prec_str(), - concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), - concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), - concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); - // clang-format on - } + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - struct BatchedGemmKernelArgs : GemmKernelArgs + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. 
+ static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); + + struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<> { index_t batch_stride_A; index_t batch_stride_B; @@ -91,27 +104,41 @@ struct BatchedGemmKernel : public GemmKernel const std::string + { + // clang-format off + using P_ = GemmPipeline; + return concat('_', "gemm_batched", gemm_prec_str(), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + + CK_TILE_HOST static constexpr auto + GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3 { return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch); } - __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return dim3(UniversalGemmKernel::KernelBlockSize); + } CK_TILE_HOST static constexpr BatchedGemmKernelArgs MakeKernelArgs(const BatchedGemmHostArgs& hostArgs) { - return BatchedGemmKernelArgs{{hostArgs.a_ptr, - hostArgs.b_ptr, - {}, + return BatchedGemmKernelArgs{{hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, hostArgs.e_ptr, hostArgs.M, hostArgs.N, hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - {}, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, hostArgs.stride_E, hostArgs.k_batch}, hostArgs.batch_stride_A, @@ -125,6 +152,12 @@ struct BatchedGemmKernel : public GemmKernel bool + { + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const { const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); @@ -134,18 +167,18 @@ struct BatchedGemmKernel : public GemmKernel(kargs.a_ptr) + batch_offset_A + - splitk_batch_offset.a_k_split_offset; + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + batch_offset_A + + splitk_batch_offset.as_k_split_offset[0]; const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B); const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B); - const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B + - splitk_batch_offset.b_k_split_offset; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + batch_offset_B + + splitk_batch_offset.bs_k_split_offset[0]; const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E); const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E); @@ -154,7 +187,8 @@ struct BatchedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + UniversalGemmKernel::RunGemm( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 53c21b49f5..079d3972d1 100755 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -12,6 +12,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/stream_utils.hpp" #include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -24,14 +25,11 @@ namespace ck_tile { /// and launch kernel on GPU. 
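// ---------------------------------------------------------------------------
// Illustration: BatchedGemmKernel above launches a three-dimensional grid,
// dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch), so each block
// decodes an (output tile, batch, K split) triple from its block index. A
// standalone sketch of the per-batch offset arithmetic operator() applies
// before delegating to UniversalGemmKernel::RunGemm (assumption: the batch
// index i_batch comes from blockIdx.y; names are illustrative):
#include <cstdint>

std::int64_t batch_offset(std::int64_t i_batch, std::int64_t batch_stride)
{
    // Batches are dense slabs in memory; the kernel adds this element offset
    // to the A, B and E base pointers, then runs an ordinary GEMM on it.
    return i_batch * batch_stride;
}
// ---------------------------------------------------------------------------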
/// This structure defines the GEMM problem configuration by stating all required information /// like M,N,K sizes and respective strides. -/// NumDTensor describes the number of D tensors. -template struct GemmHostArgs { CK_TILE_HOST GemmHostArgs() = default; CK_TILE_HOST GemmHostArgs(const void* a_ptr_, const void* b_ptr_, - const std::array& ds_ptr_, void* e_ptr_, index_t k_batch_, index_t M_, @@ -39,18 +37,15 @@ struct GemmHostArgs index_t K_, index_t stride_A_, index_t stride_B_, - const std::array& stride_Ds_, index_t stride_E_) : a_ptr(a_ptr_), b_ptr(b_ptr_), - ds_ptr(ds_ptr_), e_ptr(e_ptr_), M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), - stride_Ds(stride_Ds_), stride_E(stride_E_), k_batch(k_batch_) { @@ -58,18 +53,18 @@ struct GemmHostArgs const void* a_ptr; const void* b_ptr; - const std::array ds_ptr; union { void* e_ptr; void* c_ptr; }; + index_t M; index_t N; index_t K; index_t stride_A; index_t stride_B; - const std::array stride_Ds; + union { index_t stride_E; @@ -79,990 +74,96 @@ struct GemmHostArgs index_t k_batch; }; -/// @brief The GEMM kernel device arguments. -template -struct GemmKernelArgs -{ - /// @brief The A input tensor's pointer to device memory. - const void* a_ptr; - /// @brief The B input tensor's pointer to device memory. - const void* b_ptr; - /// @brief The Ds input tensor's pointer to device memory. - const std::array ds_ptr; - /// @brief The E output tensor's pointer to device memory. - void* e_ptr; - /// @brief GEMM's M dimension size. - index_t M; - /// @brief GEMM's N dimension size. - index_t N; - /// @brief GEMM's K dimension size. - index_t K; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of A tensor. - index_t stride_A; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of B tensor. - index_t stride_B; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of Ds tensor. - std::array stride_Ds; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of E tensor. - index_t stride_E; - index_t k_batch; -}; - -/// @brief The GEMM kernel template. -/// -/// @paragraph Overview Overview -/// This class provides the generic matrix multiplication kernel template. By semantic -/// division of GEMM algorithm into following parts we achieve flexible, versatile -/// and robust kernel implementation. -/// -/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() -/// function call operator" which determines the work scope of each workgroup. -/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. -/// This is the place where each workgroup is loading data from global memory and -/// carrying out dot products. -/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation -/// responsible for storing results to global memory. This is also the place where -/// any additional operator fusion may take place. -/// -/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ -/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all -/// internal details of those functional parts. You can think of it like both gemm and -/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover -/// the policy is responsible for definition of all necessary data layouts and thread's -/// work distribution. 
-/// -/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the -/// output data tile to be calculated. It determines the workgroup to -/// data relationship (or in other words - which data would be -/// processed and calculated by which workgroup). -/// @tparam GemmPipeline_ The type of class which provides the core part of matrix -/// multiplication. This class should provide implementation of data -/// loading from global memory and performing block-wise matrix -/// multiplication. You can think of it as a work done by single -/// workgroup point of view. -/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix -/// multiplication implementation. It is responsible for storing -/// results calculated by @ref GemmPipeline_ "GemmPipeline" to -/// the output E tensor in global memory. template struct GemmKernel { + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - // TODO: GemmPipeline::CLayout -> GemmPipeline::ELayout will be changed for multi-ABD - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - // Get the persistent kernel if the pipeline has it available - struct has_persistent_kernel - { - template - using has_persistent_type = decltype(T::UsePersistentKernel); - - static constexpr bool value = []() { - if constexpr(is_detected{}) - return GemmPipeline::UsePersistentKernel; - else - return false; - }(); - }; - static constexpr bool PersistentKernel = has_persistent_kernel::value; + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + /// @brief Specify the data type configurations for A, B, E and D using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - // Below type is actually accumulation data type - the output of block GEMM. using EDataType = remove_cvref_t; - static constexpr index_t NumDTensor = DsDataType::size(); + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); - static constexpr auto I0 = number<0>(); - static constexpr auto I1 = number<1>(); - static constexpr auto I2 = number<2>(); - static constexpr auto I3 = number<3>{}; + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - static_assert(DsLayout::size() == DsDataType::size(), - "The size of DsLayout and DsDataType should be the same"); - using KernelArgs = GemmKernelArgs; + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. 
+ static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); - [[nodiscard]] CK_TILE_HOST static const std::string GetName() + static constexpr index_t NumATensor = 1; + static constexpr index_t NumBTensor = 1; + + CK_TILE_HOST static auto GetName() -> const std::string { - // clang-format off - return concat('_', "gemm", gemm_prec_str(), GemmPipeline::GetName()); - // clang-format on + return UniversalGemmKernel::GetName(); } - CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 { - return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + return UniversalGemmKernel::GridSize(M, N, KBatch); } - /** - * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. - * @return The maximum occupancy grid size. - * @note This function queries the maximum occupancy of the kernel using - * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. - */ CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { - using Kernel = GemmKernel; - const auto kernel = kentry; - int occupancy; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); - const int grid_size = get_available_compute_units(s) * occupancy; - return dim3(grid_size, 1, 1); + return UniversalGemmKernel::MaxOccupancyGridSize(s); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - - CK_TILE_HOST static constexpr KernelArgs - MakeKernelArgs(const GemmHostArgs& hostArgs) + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { - - return KernelArgs{hostArgs.a_ptr, - hostArgs.b_ptr, - hostArgs.ds_ptr, - hostArgs.e_ptr, - hostArgs.M, - hostArgs.N, - hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - hostArgs.stride_Ds, - hostArgs.stride_E, - hostArgs.k_batch}; + return UniversalGemmKernel::BlockSize(); } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs) -> + typename UniversalGemmKernel::KernelArgs { - return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B. 
+ return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs( + {hostArgs.a_ptr}, + {hostArgs.b_ptr}, + {/*hostArgs.ds_ptr*/}, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + {hostArgs.stride_A}, + {hostArgs.stride_B}, + {/*hostArgs.stride_Ds*/}, + hostArgs.stride_E)); } - struct SplitKBatchOffset + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool { - __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) - { - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); - const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); - const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); - - if constexpr(std::is_same_v) - { - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); - } - else if constexpr(std::is_same_v) - { - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); - } - - if constexpr(std::is_same_v) - { - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); - } - else if constexpr(std::is_same_v) - { - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); - } - - if(k_id < static_cast(kargs.k_batch - 1)) - { - splitted_k = __builtin_amdgcn_readfirstlane(KRead); - } - else - { - splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); - } - } - - index_t a_k_split_offset; - index_t b_k_split_offset; - index_t splitted_k; - }; - - CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) - { - if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value) - { - if(kargs.k_batch != 1) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); - } - return false; - } - } - - if constexpr(std::is_same_v) - { - if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) // k_batch is extra compared to flatmm - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); - } - return false; - } - if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); - } - return false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support M that is not a multiple of MPerBlock without padding!"); - } - return false; - } - if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); - } - return false; - } - } - - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support N that is not a multiple of NPerBlock without padding!"); - } - return false; - } - if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); - } - return false; - } - } - else - { - if(kargs.K % 
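// ---------------------------------------------------------------------------
// Illustration: the SplitKBatchOffset constructor above rounds K up to a
// multiple of k_batch * K1 (K1 = the warp tile's K extent) to size the reads,
// and the last split takes whatever remains. A standalone worked example of
// that arithmetic:
#include <cassert>

int main()
{
    const int K = 1000, k_batch = 4, K1 = 32;   // example problem
    const int K_t   = k_batch * K1;             // 128
    const int KRead = (K + K_t - 1) / K_t * K1; // ceil(1000/128) * 32 = 256

    int total = 0;
    for(int k_id = 0; k_id < k_batch; ++k_id)
    {
        // Splits 0 .. k_batch-2 each process KRead elements of K; the last
        // split gets the remainder so the splits cover K exactly.
        const int splitted_k =
            (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
        total += splitted_k; // 256, 256, 256, 232
    }
    assert(total == K);
    return 0;
}
// ---------------------------------------------------------------------------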
(TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) // again k_batch is extra compared to flatmm - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); - } - return false; - } - if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); - } - return false; - } - } - - bool DTesnorIsValid = {true}; - static_for<0, NumDTensor, 1>{}([&](auto index) { - using DiLayout = remove_cvref_t>; - if(std::is_same_v == false) - { - DTesnorIsValid = false; - } - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " - "NPerBlock without padding!"); - } - DTesnorIsValid = false; - } - if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); - } - DTesnorIsValid = false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " - "MPerBlock without padding!"); - } - DTesnorIsValid = false; - } - if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); - } - DTesnorIsValid = false; - } - } - }); - - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support N that is not a multiple of NPerBlock without padding!"); - } - return false; - } - if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); - } - return false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support M that is not a multiple of MPerBlock without padding!"); - } - return false; - } - if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); - } - return false; - } - } - return DTesnorIsValid; + return UniversalGemmKernel::IsSupportedArgument(kargs); } - template - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void { - static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); - - const auto& a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 
1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - }(); - - const auto& b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - if constexpr(GemmPipeline::Preshuffle) - { - index_t kFlatK = - GemmPipeline::BlockGemmShape::flatKPerWarp * - (splitk_batch_offset.splitted_k / - TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - - return make_naive_tensor_view( - b_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - } - }(); - - const auto& ds_tensor_view = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - }, - number{}); - - // TODO: enable vector write for C in ColMajor - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. 
- make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, e_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_pad_view = views.at(I1); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - if constexpr(std::is_same_v) - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor - const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); - if constexpr(std::is_same_v) - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - if constexpr(GemmPipeline::Preshuffle) - { - // For flatmm, we need to use the flat B tensor view - return make_tuple(a_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); - } - else - { - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_block_window = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); - } - else - { - if constexpr(std::is_same_v) - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - } - else - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {0, i_n}); - } - } - }(); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( - e_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple(a_block_window, 
b_block_window, ds_block_window, e_block_window); - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - template - CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* smem_ptr_0, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0); - - if(UseDefaultScheduler || (get_warp_id() == 0)) - { - auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
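// ---------------------------------------------------------------------------
// Illustration: RunGemm2LDS, whose signature continues below, drives the
// pipeline with two LDS buffers in the ping-pong fashion the note above
// describes, so one buffer can be filled while the other is consumed. A
// standalone sketch of the buffer-role swap only; load_tile/compute_tile are
// placeholders, not ck_tile APIs:
inline void ping_pong_sketch(char* smem_0, char* smem_1, int num_loop)
{
    char* fill_buf = smem_0; // written by global->LDS loads
    char* math_buf = smem_1; // consumed by the block GEMM
    for(int i = 0; i < num_loop; ++i)
    {
        // load_tile(fill_buf);    // prefetch for loop iteration i + 1
        // compute_tile(math_buf); // math for loop iteration i
        char* tmp = fill_buf;      // swap roles each iteration
        fill_buf  = math_buf;
        math_buf  = tmp;
    }
}
// ---------------------------------------------------------------------------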
- * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - // Non-persistent kernel entry point - template > - CK_TILE_DEVICE void operator()(KernelArgs kargs) const - { - const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - const SplitKBatchOffset splitk_batch_offset(kargs); - - // options - const ADataType* a_ptr = - static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - - EDataType* e_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - } - - // Persistent kernel entry point - template , typename = void> - CK_TILE_DEVICE void operator()(KernelArgs kargs) const - { - const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); - const auto num_tiles = - __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); - const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); - auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); - - while(block_id < num_work) - { - // Get the tile index for this block - const auto tile_idx = 
__builtin_amdgcn_readfirstlane(block_id % num_tiles); - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - // Get the SplitK offset for this block - const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); - const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); - const ADataType* a_ptr = - static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - EDataType* e_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - // Run the GEMM - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - // Advance to the next work item - block_id += grid_size; - if(block_id >= num_work) - { - break; - } - } + UniversalGemmKernel{}.template operator()(kargs); } }; - } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp new file mode 100644 index 0000000000..34340008d4 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The MultiD GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GemmKernelMultiD "GemmKernelMultiD" when creating kernel +/// arguments object. It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. NumDTensor +/// describes the number of D tensors. 
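+///
+/// @par Example
+///     A minimal host-side sketch. The device pointer names and the explicit NumD
+///     template argument are illustrative assumptions for this example, not part of
+///     the change; here both D tensors are full MxN row-major tensors sharing E's
+///     stride.
+/// @code
+///     constexpr ck_tile::index_t NumD = 2;
+///     std::array<const void*, NumD> ds_ptr         = {d0_dev_ptr, d1_dev_ptr};
+///     std::array<ck_tile::index_t, NumD> stride_Ds = {stride_E, stride_E};
+///     ck_tile::GemmMultiDHostArgs<NumD> args(a_dev_ptr, b_dev_ptr, ds_ptr, e_dev_ptr,
+///                                            /*k_batch=*/1, M, N, K,
+///                                            stride_A, stride_B, stride_Ds, stride_E);
+/// @endcode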
+template 
+struct GemmMultiDHostArgs
+{
+    CK_TILE_HOST GemmMultiDHostArgs() = default;
+    CK_TILE_HOST GemmMultiDHostArgs(const void* a_ptr_,
+                                    const void* b_ptr_,
+                                    const std::array& ds_ptr_,
+                                    void* e_ptr_,
+                                    index_t k_batch_,
+                                    index_t M_,
+                                    index_t N_,
+                                    index_t K_,
+                                    index_t stride_A_,
+                                    index_t stride_B_,
+                                    const std::array& stride_Ds_,
+                                    index_t stride_E_)
+        : a_ptr(a_ptr_),
+          b_ptr(b_ptr_),
+          ds_ptr(ds_ptr_),
+          e_ptr(e_ptr_),
+          M(M_),
+          N(N_),
+          K(K_),
+          stride_A(stride_A_),
+          stride_B(stride_B_),
+          stride_Ds(stride_Ds_),
+          stride_E(stride_E_),
+          k_batch(k_batch_)
+    {
+    }
+
+    const void* a_ptr;
+    const void* b_ptr;
+    const std::array ds_ptr;
+    union
+    {
+        void* e_ptr;
+        void* c_ptr;
+    };
+    index_t M;
+    index_t N;
+    index_t K;
+    index_t stride_A;
+    index_t stride_B;
+    const std::array stride_Ds;
+    union
+    {
+        index_t stride_E;
+        index_t stride_C;
+    };
+
+    index_t k_batch;
+};
+
+template 
+struct GemmKernelMultiD
+{
+    /// @brief The underlying UniversalGemmKernel to which all host- and device-side
+    ///        functionality is delegated.
+    using UniversalGemmKernel =
+        UniversalGemmKernel;
+
+    using TilePartitioner  = remove_cvref_t;
+    using GemmPipeline     = remove_cvref_t;
+    using EpiloguePipeline = remove_cvref_t;
+
+    /// @brief Specify the layout configurations for A, B, E and D
+    using ALayout  = remove_cvref_t;
+    using BLayout  = remove_cvref_t;
+    using ELayout  = remove_cvref_t;
+    using DsLayout = remove_cvref_t;
+
+    /// @brief Specify the data type configurations for A, B, E and D
+    using ADataType  = remove_cvref_t;
+    using BDataType  = remove_cvref_t;
+    using EDataType  = remove_cvref_t;
+    using DsDataType = remove_cvref_t;
+
+    /// @brief ALayout and ADataType are expected to be scalars, not tuples.
+    static_assert(!is_detected::value &&
+                      !is_detected::value,
+                  "ALayout and ADataType must be scalars.");
+
+    /// @brief BLayout and BDataType are expected to be scalars, not tuples.
+    static_assert(!is_detected::value &&
+                      !is_detected::value,
+                  "BLayout and BDataType must be scalars.");
+
+    /// @brief ELayout and EDataType are expected to be scalars, not tuples.
+    static_assert(!is_detected::value &&
+                      !is_detected::value,
+                  "ELayout and EDataType must be scalars.");
+
+    /// @brief DsLayout and DsDataType are expected to be tuples, not scalars.
+    static_assert(is_detected::value &&
+                      is_detected::value &&
+                      DsLayout::size() == DsDataType::size() && DsLayout::size() > 0,
+                  "DsLayout and DsDataType must be tuples and must have the same size.");
+
+    /// @brief NumATensor and NumBTensor are always 1; the number of D tensors is set by
+    ///        the user.
+    static constexpr index_t NumATensor = 1;
+    static constexpr index_t NumBTensor = 1;
+    static constexpr index_t NumDTensor = DsDataType::size();
+
+    CK_TILE_HOST static auto GetName() -> const std::string
+    {
+        return UniversalGemmKernel::GetName();
+    }
+
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3
+    {
+        return UniversalGemmKernel::GridSize(M, N, KBatch);
+    }
+
+    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
+    {
+        return UniversalGemmKernel::MaxOccupancyGridSize(s);
+    }
+
+    CK_TILE_HOST static constexpr auto BlockSize() -> dim3
+    {
+        return UniversalGemmKernel::BlockSize();
+    }
+
+    CK_TILE_HOST static constexpr auto
+    MakeKernelArgs(const GemmMultiDHostArgs& hostArgs) ->
+        typename UniversalGemmKernel::KernelArgs
+    {
+        /// @brief Universal GEMM requires array objects and corresponding stride information
+        ///        for matrices A, B, and D.
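+        /// The single A and B tensors are wrapped into one-element arrays
+        /// ({a_ptr}, {stride_A}, ...) to satisfy the universal interface, while the
+        /// D pointer and stride arrays are forwarded unchanged.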
+ return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs({hostArgs.a_ptr}, + {hostArgs.b_ptr}, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + {hostArgs.stride_A}, + {hostArgs.stride_B}, + hostArgs.stride_Ds, + hostArgs.stride_E)); + } + + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + { + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void + { + UniversalGemmKernel{}.template operator()(kargs); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 2605b1afbc..8716475869 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -16,37 +16,116 @@ namespace ck_tile { +/// @brief The Grouped GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel +/// arguments object. It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. +struct GroupedGemmHostArgs +{ + CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + struct GemmTransKernelArg { - GemmKernelArgs<> group_karg; + UniversalGemmKernelArgs<> group_karg; ck_tile::index_t block_start; ck_tile::index_t block_end; GemmTransKernelArg() = delete; - GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) + GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } - GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {} + GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg) + : group_karg{karg}, block_start{0}, block_end{0} + { + } }; template -struct GroupedGemmKernel : public GemmKernel +struct GroupedGemmKernel { + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. 
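+    /// The grouped kernel reuses its SplitKBatchOffset handling, tensor-view and
+    /// pad-view construction, and the RunGemm implementation.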
+ using Base = UniversalGemmKernel; + using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + //// @brief Specify the layout configurations for A, B, C/E + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, C/E using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); + using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; - using Base = GemmKernel; using Kernel = GroupedGemmKernel; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; @@ -65,8 +144,8 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) -> std::size_t + CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector& gemm_descs) + -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } @@ -95,8 +174,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) + CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -107,8 +185,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) + CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; @@ -138,18 +215,19 @@ struct GroupedGemmKernel : public GemmKernel{type_convert(gemm_descs[i].a_ptr), - type_convert(gemm_descs[i].b_ptr), - {}, - type_convert(gemm_descs[i].e_ptr), - M, - N, - K, - stride_a, - stride_b, - {}, - stride_e, - gemm_descs[i].k_batch}; + auto karg = + UniversalGemmKernelArgs<>{{type_convert(gemm_descs[i].a_ptr)}, + {type_convert(gemm_descs[i].b_ptr)}, + {/*ds_ptr*/}, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + {stride_a}, + {stride_b}, + {/*stride_ds*/}, + stride_e, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -181,7 +259,7 @@ struct GroupedGemmKernel : public GemmKernel& kargs, + CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs, const tuple& block_idx_2d, const index_t block_idx_z) const { @@ -192,10 +270,10 @@ struct GroupedGemmKernel : public GemmKernel(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + + splitk_batch_offset.bs_k_split_offset[0]; CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS @@ -208,7 +286,15 @@ struct GroupedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + Base::RunGemm({a_ptr}, + 
{b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n); } } @@ -224,7 +310,8 @@ struct GroupedGemmKernel : public GemmKernel& kargs, + const UniversalGemmKernelArgs<>& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -242,7 +329,7 @@ struct GroupedGemmKernel : public GemmKernel( - a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = @@ -258,8 +345,12 @@ struct GroupedGemmKernel : public GemmKernel +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The Universal GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref UniversalGemmKernel "UniversalGemmKernel" when creating +/// kernel arguments object. It contain all necessary information required to build proper +/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required). +/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required). +/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not +/// required). +template +struct UniversalGemmHostArgs +{ + CK_TILE_HOST UniversalGemmHostArgs(const std::array& as_ptr_, + const std::array& bs_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + const std::array& stride_As_, + const std::array& stride_Bs_, + const std::array& stride_Ds_, + index_t stride_E_) + : as_ptr(as_ptr_), + bs_ptr(bs_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_As(stride_As_), + stride_Bs(stride_Bs_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const std::array as_ptr; + const std::array bs_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + index_t M; + index_t N; + index_t K; + const std::array stride_As; + const std::array stride_Bs; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +/// @brief The GEMM kernel device arguments. +template +struct UniversalGemmKernelArgs +{ + /// @brief The As input tensor's pointer to device memory. + const std::array as_ptr; + /// @brief The Bs input tensor's pointer to device memory. + const std::array bs_ptr; + /// @brief The Ds input tensor's pointer to device memory. + const std::array ds_ptr; + /// @brief The E output tensor's pointer to device memory. + void* e_ptr; + /// @brief GEMM's M dimension size. + index_t M; + /// @brief GEMM's N dimension size. + index_t N; + /// @brief GEMM's K dimension size. + index_t K; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of As tensor. + std::array stride_As; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Bs tensor. 
+    std::array stride_Bs;
+    /// @brief The distance between consecutive elements of non-contiguous dimension
+    ///        (in memory) of Ds tensor.
+    std::array stride_Ds;
+    /// @brief The distance between consecutive elements of non-contiguous dimension
+    ///        (in memory) of E tensor.
+    index_t stride_E;
+    index_t k_batch;
+};
+
+/// @brief The Universal GEMM kernel template.
+///
+/// @paragraph Overview Overview
+///           This class provides the generic matrix multiplication kernel template. By
+///           semantically dividing the GEMM algorithm into the following parts we achieve a
+///           flexible, versatile and robust kernel implementation.
+///
+///           @li @b Prolog - The start of the GEMM kernel implementation in the @ref operator()
+///               "function call operator", which determines the work scope of each workgroup.
+///           @li @b GemmPipeline - The core part, the @a "heart" of the matrix multiplication
+///               algorithm. This is the place where each workgroup loads data from global
+///               memory and carries out dot products.
+///           @li @b Epilogue - The @a "final" part of the matrix multiplication implementation,
+///               responsible for storing results to global memory. This is also the place where
+///               any additional operator fusion may take place.
+///
+///           Additionally, both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
+///           "EpiloguePipeline" are parameterized with a so-called @a Policy, which determines
+///           all internal details of those functional parts. You can think of it this way: the
+///           gemm and epilogue pipelines provide the control-flow logic, which is steered by
+///           the policies. Moreover, the policy is responsible for defining all necessary data
+///           layouts and the threads' work distribution.
+///
+/// @tparam TilePartitioner_   The type of class providing the mapping of workgroup index onto
+///                            the output data tile to be calculated. It determines the
+///                            workgroup-to-data relationship (in other words, which data is
+///                            processed and calculated by which workgroup).
+/// @tparam GemmPipeline_      The type of class which provides the core part of the matrix
+///                            multiplication. This class should implement data loading from
+///                            global memory and block-wise matrix multiplication - the work
+///                            done by a single workgroup.
+/// @tparam EpiloguePipeline_  The type of class providing the final part of the matrix
+///                            multiplication implementation. It is responsible for storing
+///                            the results calculated by @ref GemmPipeline_ "GemmPipeline" to
+///                            the output E tensor in global memory.
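+///
+/// @par Example
+///     A host-side composition sketch; the partitioner, pipeline and epilogue type
+///     names are illustrative placeholders, and the launch helper's template
+///     parameters are elided as elsewhere in this change:
+/// @code
+///     using Kernel = ck_tile::UniversalGemmKernel<MyTilePartitioner,
+///                                                 MyGemmPipeline,
+///                                                 MyCShuffleEpilogue>;
+///     auto kargs = Kernel::MakeKernelArgs(host_args); // from UniversalGemmHostArgs
+///     if(!Kernel::IsSupportedArgument(kargs)) { /* fall back or skip */ }
+///     const dim3 grids  = Kernel::GridSize(M, N, k_batch);
+///     const dim3 blocks = Kernel::BlockSize();
+///     ck_tile::launch_kernel(stream,
+///                            ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs));
+/// @endcode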
+template +struct UniversalGemmKernel +{ + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + static constexpr bool ADataTypeIsTuple = + is_detected::value; + static constexpr bool BDataTypeIsTuple = + is_detected::value; + static constexpr bool DDataTypeIsTuple = + is_detected::value; + static constexpr bool ALayoutIsTuple = + is_detected::value; + static constexpr bool BLayoutIsTuple = + is_detected::value; + static constexpr bool DLayoutIsTuple = + is_detected::value; + + using AsLayout = std::conditional_t, + remove_cvref_t>>; + using BsLayout = std::conditional_t, + remove_cvref_t>>; + + using DsLayout = std::conditional_t, + remove_cvref_t>>; + + using AsDataType = std::conditional_t, + remove_cvref_t>>; + + using BsDataType = std::conditional_t, + remove_cvref_t>>; + + using DsDataType = + std::conditional_t, + remove_cvref_t>>; + + using ELayout = remove_cvref_t; + using EDataType = remove_cvref_t; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + // Get the persistent kernel if the pipeline has it available + struct has_persistent_kernel + { + template + using has_persistent_type = decltype(T::UsePersistentKernel); + + static constexpr bool value = []() { + if constexpr(is_detected{}) + return GemmPipeline::UsePersistentKernel; + else + return false; + }(); + }; + static constexpr bool PersistentKernel = has_persistent_kernel::value; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>{}; + + static constexpr index_t NumATensor = AsDataType::size(); + static constexpr index_t NumBTensor = BsDataType::size(); + static constexpr index_t NumDTensor = DsDataType::size(); + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(AsLayout::size() == AsDataType::size(), + "The size of AsLayout and AsDataType should be the same"); + + static_assert(BsLayout::size() == BsDataType::size(), + "The size of BsLayout and BsDataType should be the same"); + + static_assert(DsLayout::size() == DsDataType::size(), + "The size of DsLayout and DsDataType should be the same"); + + using KernelArgs = + UniversalGemmKernelArgs; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str(), GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + /** + * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + * @return The maximum occupancy grid size. + * @note This function queries the maximum occupancy of the kernel using + * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. 
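+     * @note The persistent kernel entry point uses this grid size; each workgroup
+     *       then strides through the output tiles in steps of the grid size.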
+ */ + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + using Kernel = UniversalGemmKernel; + const auto kernel = kentry; + int occupancy; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr KernelArgs + MakeKernelArgs(const UniversalGemmHostArgs& hostArgs) + { + return KernelArgs{hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, + hostArgs.stride_E, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + + static_for<0, NumATensor, 1>{}([&](auto index) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + as_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + else if constexpr(std::is_same_v) + { + as_k_split_offset[index] = + __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]); + } + }); + + static_for<0, NumBTensor, 1>{}([&](auto index) { + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + bs_k_split_offset[index] = + __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]); + } + else if constexpr(std::is_same_v) + { + bs_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + }); + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + } + } + + std::array as_k_split_offset; + std::array bs_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) + { + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + } + + bool AsTesnorIsValid = {true}; + static_for<0, NumATensor, 1>{}([&](auto index) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + AsTesnorIsValid = false; + } + if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + AsTesnorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + 
CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + AsTesnorIsValid = false; + } + if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } + AsTesnorIsValid = false; + } + } + }); + + bool BsTesnorIsValid = {true}; + static_for<0, NumBTensor, 1>{}([&](auto index) { + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + BsTesnorIsValid = false; + } + if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + BsTesnorIsValid = false; + } + } + else + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + BsTesnorIsValid = false; + } + if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + } + BsTesnorIsValid = false; + } + } + }); + + bool DTesnorIsValid = {true}; + static_for<0, NumDTensor, 1>{}([&](auto index) { + using DiLayout = remove_cvref_t>; + if(std::is_same_v == false) + { + DTesnorIsValid = false; + } + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " + "NPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " + "MPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + }); + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + 
CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + + const auto& as_tensor_view = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + using AiDataType = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(as_ptr[i]), + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_As[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(as_ptr[i]), + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_As[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + const auto& bs_tensor_view = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + using BiDataType = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = + std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view( + static_cast(bs_ptr[i]), b_n_k_desc); + } + else + { + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_Bs[i], 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = + std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view( + static_cast(bs_ptr[i]), b_n_k_desc); + } + else + { + if constexpr(GemmPipeline::Preshuffle) + { + index_t kFlatK = + GemmPipeline::BlockGemmShape::flatKPerWarp * + (splitk_batch_offset.splitted_k / + TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + 
} + else + { + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_Bs[i], 1), + number{}, + number<1>{}); + } + } + } + }, + number{}); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + // TODO: enable vector write for C in ColMajor + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. + make_tuple(kargs.stride_E, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_E), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& as_pad_view = generate_tuple( + [&](auto i) { + const auto& a_tensor_view = views.at(I0); + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + const auto& b_flat_pad_view = views.at(I1); + + const auto& bs_pad_view = generate_tuple( + [&](auto i) { + const auto& b_tensor_view = views.at(I1); + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + const auto& d_tensor_view = views.at(I2); + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // TODO vector write in for C in ColMajor + const auto& e_pad_view = [&]() { + const auto& e_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + if constexpr(GemmPipeline::Preshuffle) + { + // For flatmm, we need to use the flat B tensor view + return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); + } + else + { + return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view); + } + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& as_pad_view = views.at(I0); + const auto& bs_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = 
views.at(I3);
+
+        const auto& as_block_window = generate_tuple(
+            [&](auto i) {
+                using AiLayout = remove_cvref_t>;
+                if constexpr(std::is_same_v)
+                {
+                    return make_tile_window(as_pad_view[i],
+                                            make_tuple(number{},
+                                                       number{}),
+                                            {i_m, 0});
+                }
+                else
+                {
+                    return make_tile_window(as_pad_view[i],
+                                            make_tuple(number{},
+                                                       number{}),
+                                            {0, i_m});
+                }
+            },
+            number{});
+
+        const auto& bs_block_window = generate_tuple(
+            [&](auto i) {
+                using BiLayout = remove_cvref_t>;
+                if constexpr(GemmPipeline::Preshuffle)
+                {
+                    return make_tile_window(
+                        bs_pad_view[i],
+                        make_tuple(number{},
+                                   number{}),
+                        {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
+                         0});
+                }
+                else
+                {
+                    if constexpr(std::is_same_v)
+                    {
+                        return make_tile_window(bs_pad_view[i],
+                                                make_tuple(number{},
+                                                           number{}),
+                                                {i_n, 0});
+                    }
+                    else
+                    {
+                        return make_tile_window(bs_pad_view[i],
+                                                make_tuple(number{},
+                                                           number{}),
+                                                {0, i_n});
+                    }
+                }
+            },
+            number{});
+
+        const auto ds_block_window = generate_tuple(
+            [&](auto i) {
+                using DiLayout = remove_cvref_t>;
+                if constexpr(std::is_same_v)
+                {
+                    return make_tile_window(ds_pad_view[i],
+                                            make_tuple(number{},
+                                                       number{}),
+                                            {i_m, i_n});
+                }
+                else
+                {
+                    return make_tile_window(ds_pad_view[i],
+                                            make_tuple(number{},
+                                                       number{}),
+                                            {i_n, i_m});
+                }
+            },
+            number{});
+
+        auto e_block_window = make_tile_window(
+            e_pad_view,
+            make_tuple(number{}, number{}),
+            {i_m, i_n});
+
+        return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
+    }
+
+    /**
+     * @brief Runs a single GEMM problem cooperatively by the whole workgroup.
+     *
+     * @param as_ptr input As pointer
+     * @param bs_ptr input Bs pointer
+     * @param ds_ptr input Ds pointer
+     * @param e_ptr output E pointer
+     * @param smem_ptr_0 The start memory pointer of the shared memory block.
+     * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset Utility structure used to calculate the k-split batch offsets.
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     */
+    template 
+    CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr,
+                                       const std::array& bs_ptr,
+                                       const std::array& ds_ptr,
+                                       EDataType* e_ptr,
+                                       void* smem_ptr_0,
+                                       const KernelArgs& kargs,
+                                       const SplitKBatchOffset& splitk_batch_offset,
+                                       const index_t block_idx_m,
+                                       const index_t block_idx_n)
+    {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews(
+                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
+
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop = __builtin_amdgcn_readfirstlane(
+            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
+
+        // Run GEMM cooperatively by the whole workgroup.
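+        // Tile-window tuple layout: I0 -> As windows, I1 -> Bs windows,
+        // I2 -> Ds windows, I3 -> E window. Note that the pipeline below is
+        // invoked with only the first A and B windows (index I0 of each array).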
+        const auto& as_block_window = gemm_tile_windows.at(I0);
+        const auto& bs_block_window = gemm_tile_windows.at(I1);
+        const auto& ds_block_window = gemm_tile_windows.at(I2);
+
+        const auto& c_block_tile = GemmPipeline{}.template operator()(
+            as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
+
+        if(UseDefaultScheduler || (get_warp_id() == 0))
+        {
+            // Run Epilogue Pipeline
+            auto& c_block_window = gemm_tile_windows.at(I3);
+
+            EpiloguePipeline{}.template
+            operator()(
+                c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
+        }
+    }
+
+    /**
+     * @brief Runs a single GEMM problem cooperatively by the whole workgroup.
+     *
+     * @note RunGemm2LDS runs with two shared memory buffers, using a ping-pong
+     *       buffering mechanism.
+     *
+     * @param as_ptr input As pointer
+     * @param bs_ptr input Bs pointer
+     * @param ds_ptr input Ds pointer
+     * @param e_ptr output E pointer
+     * @param smem_ptr_0 The starting pointer of the 1st shared memory block.
+     * @param smem_ptr_1 The starting pointer of the 2nd shared memory block.
+     * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset Utility structure used to calculate the k-split batch offsets.
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     */
+    CK_TILE_DEVICE static void RunGemm2LDS(const std::array& as_ptr,
+                                           const std::array& bs_ptr,
+                                           const std::array& ds_ptr,
+                                           EDataType* e_ptr,
+                                           void* __restrict__ smem_ptr_0,
+                                           void* __restrict__ smem_ptr_1,
+                                           const KernelArgs& kargs,
+                                           const SplitKBatchOffset& splitk_batch_offset,
+                                           const index_t block_idx_m,
+                                           const index_t block_idx_n)
+    {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews(
+                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
+
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop = __builtin_amdgcn_readfirstlane(
+            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
+
+        // Run GEMM cooperatively by the whole workgroup.
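+        // Double-LDS variant: both shared memory buffers are handed to the
+        // pipeline, which ping-pongs between them so that global-memory loads into
+        // one buffer can overlap math on the other; the epilogue reuses smem_ptr_0.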
+ const auto& as_block_window = gemm_tile_windows.at(I0); + const auto& bs_block_window = gemm_tile_windows.at(I1); + const auto& ds_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + + // Non-persistent kernel entry point + template > + CK_TILE_DEVICE void operator()(KernelArgs kargs) const + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + + // options + std::array as_ptr; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr[i] = static_cast(kargs.as_ptr[i]) + + splitk_batch_offset.as_k_split_offset[i]; + }); + + std::array bs_ptr; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i]; + }); + + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + } + + // Persistent kernel entry point + template , typename = void> + CK_TILE_DEVICE void operator()(KernelArgs kargs) const + { + const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); + const auto num_tiles = + __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); + const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); + auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); + + while(block_id < num_work) + { + // Get the tile index for this block + const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + // Get the SplitK offset for this block + const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); + const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); + + std::array as_ptr; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr[i] = static_cast(kargs.as_ptr[i]) + + splitk_batch_offset.as_k_split_offset[i]; + }); + + std::array bs_ptr; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_ptr[i] = 
static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i]; + }); + + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + // Run the GEMM + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + // Advance to the next work item + block_id += grid_size; + if(block_id >= num_work) + { + break; + } + } + } +}; +} // namespace ck_tile diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 79bd51d65c..f654d1a917 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -242,21 +242,20 @@ class TestCkTileBatchedGemm : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - ck_tile::BatchedGemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = 1; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = StrideA; - args.stride_B = StrideB; - args.stride_E = StrideC; - args.batch_stride_A = BatchStrideA; - args.batch_stride_B = BatchStrideB; - args.batch_stride_E = BatchStrideC; - args.batch_count = BatchCount; + ck_tile::BatchedGemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + 1, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchCount}; invoke_batched_gemm(args, ck_tile::stream_config{nullptr, false}); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 9e4c036655..4321709ea5 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -25,7 +25,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { if constexpr(Persistent) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc index afa6912e0f..a967b92e7f 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -158,7 +158,7 @@ template -float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); template args = {a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - {}, - c_m_n_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - {}, - stride_C}; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + 
c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; float ave_time; if(persistent) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp index 99a1e50a6f..bd197150a4 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -411,4 +411,4 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index 1980648391..860541ef18 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -14,7 +14,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -63,119 +63,120 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw ArgumentsNotSupportedException( - "Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw ArgumentsNotSupportedException( + "Wrong! Arguments not supported! 
Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const 
auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 7b519760b9..9adf9ec185 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -91,8 +91,7 @@ class TestCkTileGemmPipeline : public ::testing::Test // TODO: expose tile size through test t-param ? template - void invoke_gemm(const ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) + void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 256; @@ -324,9 +323,9 @@ class TestCkTileGemmPipeline : public ::testing::Test return stride; }; - std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{}); - std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{}); - std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{}); + ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{}); + ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{}); + ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{}); ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); @@ -345,17 +344,16 @@ class TestCkTileGemmPipeline : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - ck_tile::GemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_E = stride_C; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; invoke_gemm(args, ck_tile::stream_config{nullptr, false}); diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 7dd91077b1..c08951435e 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -10,7 +10,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" struct ElementWiseAddAdd @@ -95,7 +95,7 @@ class TestCkTileGemmMultiD : public ::testing::Test typename DsLayout, typename ELayout, typename CDEElementWise = ck_tile::element_wise::PassThrough> - void invoke_gemm_multi_d(const ck_tile::GemmHostArgs& args, + void invoke_gemm_multi_d(const ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& s) { constexpr ck_tile::index_t M_Tile = 256; @@ -189,7 +189,7 @@ class TestCkTileGemmMultiD : public ::testing::Test UniversalGemmProblem::TransposeC, memory_operation>>; - using Kernel = ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); @@ -345,18 +345,18 @@ class TestCkTileGemmMultiD : public ::testing::Test d1_m_n_dev_buf.GetDeviceBuffer()}; std::array 
diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp
index 7dd91077b1..c08951435e 100644
--- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp
+++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp
@@ -10,7 +10,7 @@
 #include "ck_tile/host/kernel_launch.hpp"
 #include "ck_tile/ops/epilogue.hpp"
 #include "ck_tile/ops/gemm.hpp"
-#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
+#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
 #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
 struct ElementWiseAddAdd
@@ -95,7 +95,7 @@ class TestCkTileGemmMultiD : public ::testing::Test
              typename DsLayout,
              typename ELayout,
              typename CDEElementWise = ck_tile::element_wise::PassThrough>
-    void invoke_gemm_multi_d(const ck_tile::GemmHostArgs& args,
+    void invoke_gemm_multi_d(const ck_tile::GemmMultiDHostArgs& args,
                              const ck_tile::stream_config& s)
     {
         constexpr ck_tile::index_t M_Tile = 256;
@@ -189,7 +189,7 @@ class TestCkTileGemmMultiD : public ::testing::Test
                                                    UniversalGemmProblem::TransposeC,
                                                    memory_operation>>;
-        using Kernel = ck_tile::GemmKernel;
+        using Kernel = ck_tile::GemmKernelMultiD;
         auto kargs = Kernel::MakeKernelArgs(args);
         const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
@@ -345,18 +345,18 @@ class TestCkTileGemmMultiD : public ::testing::Test
                                                   d1_m_n_dev_buf.GetDeviceBuffer()};
         std::array stridesDs = {StrideD0, StrideD1};
-        ck_tile::GemmHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(),
-                                    b_k_n_dev_buf.GetDeviceBuffer(),
-                                    ds_ptr_buf,
-                                    e_m_n_dev_buf.GetDeviceBuffer(),
-                                    k_batch,
-                                    M,
-                                    N,
-                                    K,
-                                    StrideA,
-                                    StrideB,
-                                    stridesDs,
-                                    StrideE});
+        ck_tile::GemmMultiDHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(),
+                                          b_k_n_dev_buf.GetDeviceBuffer(),
+                                          ds_ptr_buf,
+                                          e_m_n_dev_buf.GetDeviceBuffer(),
+                                          k_batch,
+                                          M,
+                                          N,
+                                          K,
+                                          StrideA,
+                                          StrideB,
+                                          stridesDs,
+                                          StrideE});
         invoke_gemm_multi_d
diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp
--- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp
+++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp
-    void invoke_gemm(const ck_tile::GemmHostArgs& args,
-                     const ck_tile::stream_config& s)
+    void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
     {
         // TODO: This should be parameterized in tests
         // constexpr ck_tile::index_t M_Tile = 128;
@@ -314,9 +313,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
             return stride;
         };
-        std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
-        std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
-        std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
+        ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
+        ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
+        ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});
         ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
         ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
@@ -346,17 +345,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
         c_m_n_dev_buf.SetZero();
         c_m_n_dev_result.SetZero();
-        ck_tile::GemmHostArgs args;
-        args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
-        args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
-        args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
-        args.k_batch = kbatch;
-        args.M = M;
-        args.N = N;
-        args.K = K;
-        args.stride_A = stride_A;
-        args.stride_B = stride_B;
-        args.stride_E = stride_C;
+        ck_tile::GemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
+                                   b_k_n_dev_buf.GetDeviceBuffer(),
+                                   c_m_n_dev_buf.GetDeviceBuffer(),
+                                   kbatch,
+                                   M,
+                                   N,
+                                   K,
+                                   stride_A,
+                                   stride_B,
+                                   stride_C};
         invoke_gemm(args, ck_tile::stream_config{nullptr, false});
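Both test utilities above change stride variables from `std::size_t` to `ck_tile::index_t`, so stride values reach the kernel arguments already in the index type the device code uses instead of being narrowed implicitly at the call site. A hedged sketch of making such a narrowing explicit and checked (it assumes a 32-bit signed index type as a stand-in for `ck_tile::index_t`, and `to_index` is a hypothetical helper, not a ck_tile API):

```cpp
// Sketch of the std::size_t -> index_t narrowing the stride changes above
// address; assumes index_t is 32-bit signed, as a stand-in for ck_tile.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>

using index_t = std::int32_t;

index_t to_index(std::size_t v)
{
    // Reject values that would silently wrap when truncated to 32 bits.
    if(v > static_cast<std::size_t>(std::numeric_limits<index_t>::max()))
        throw std::overflow_error("stride does not fit in index_t");
    return static_cast<index_t>(v); // explicit, checked narrowing
}

int main()
{
    std::size_t stride_A = 4096; // e.g. K for a row-major A tensor
    index_t s = to_index(stride_A);
    return s == 4096 ? 0 : 1;
}
```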
diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
index 54f772f89e..79e29f8b99 100644
--- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
+++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp
@@ -51,7 +51,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
         static const ck_tile::index_t K_Warp_Tile = 16;
     };
-    using grouped_gemm_kargs = ck_tile::GemmHostArgs;
+    using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
     std::size_t get_workspace_size(const std::vector& gemm_descs)
     {
         return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
     }
@@ -437,7 +437,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
             void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
             gemm_descs.push_back(
-                {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
+                {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
         }
         ck_tile::DeviceMem gemm_workspace;
@@ -451,18 +451,18 @@ class TestCkTileGroupedGemm : public ::testing::Test
         const bool splitk = gemm_descs[0].k_batch > 1;
         for(const auto& arg : gemm_descs)
         {
-            kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
-                                                         arg.b_ptr,
-                                                         {},
-                                                         arg.e_ptr,
-                                                         arg.M,
-                                                         arg.N,
-                                                         arg.K,
-                                                         arg.stride_A,
-                                                         arg.stride_B,
-                                                         {},
-                                                         arg.stride_E,
-                                                         arg.k_batch});
+            kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr},
+                                                                  {arg.b_ptr},
+                                                                  {/*arg.ds_ptr*/},
+                                                                  arg.e_ptr,
+                                                                  arg.M,
+                                                                  arg.N,
+                                                                  arg.K,
+                                                                  {arg.stride_A},
+                                                                  {arg.stride_B},
+                                                                  {/*arg.stride_Ds*/},
+                                                                  arg.stride_E,
+                                                                  arg.k_batch});
         }
         const auto stream = ck_tile::stream_config{nullptr, false, 1};
         ck_tile::hip_check_error(
diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py
index 0b38c44a1a..6796121328 100755
--- a/tile_engine/ops/gemm/gemm_instance_builder.py
+++ b/tile_engine/ops/gemm/gemm_instance_builder.py
@@ -233,7 +233,7 @@ struct GemmKernel {{
     static constexpr bool kPadN = {pad_n};
     static constexpr bool kPadK = {pad_k};
-    static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
+    static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
        static constexpr bool permuteA = false;
        static constexpr bool permuteB = false;
        static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
@@ -335,7 +335,7 @@ struct GemmKernel {{
         auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
         ck_tile::RotatingMemWrapper rotating_mem(
-            kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer);
+            kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer);
         rotating_mem.Print();
         auto run_flush_cache = [&]() {{
@@ -680,7 +680,7 @@ struct GemmDispatcher {
     // Use a static local variable
     static std::unordered_map<
         std::string,
-        std::vector(ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>>
+        std::vector(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>>
         kernel_map;
     return kernel_map;
 }
@@ -705,7 +705,7 @@ struct GemmDispatcher {
             warp_tile_n,
             warp_tile_k,
         ) = tile[j]
-        content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
+        content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
         content += f"""
             if(structured_sparsity){{ // SMFMA"""
         sparse = (
@@ -746,7 +746,7 @@ struct GemmDispatcher {
     content += """
 }
 template
-static std::tuple run_kernel(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream)
+static std::tuple run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream)
 {
     std::string name = Kernel::get_name();
     float avg_time = Kernel::launch(args, stream);
diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp
index 2b0cbe7880..fdad363f7c 100644
--- a/tile_engine/ops/gemm/gemm_profiler.hpp
+++ b/tile_engine/ops/gemm/gemm_profiler.hpp
@@ -22,7 +22,7 @@ class GemmProfiler
     void benchmark(GemmProblem& gemm_problem,
                    std::vector(
-                       ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
+                       ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
     {
         const ALayout layout_a = ALayout{};
         const BLayout layout_b = BLayout{};
@@ -89,10 +89,9 @@ class GemmProfiler
         c_m_n_dev_buf.SetZero();
         c_m_n_dev_result.SetZero();
-        ck_tile::GemmHostArgs<> gemm_args = {
+        ck_tile::GemmHostArgs gemm_args = {
             a_m_k_dev_buf.GetDeviceBuffer(),
             b_k_n_dev_buf.GetDeviceBuffer(),
-            {}, // ds_ptr
            c_m_n_dev_buf.GetDeviceBuffer(),
            gemm_problem.split_k_,
            gemm_problem.m_,
@@ -100,7 +99,6 @@ class GemmProfiler
            gemm_problem.k_,
            gemm_problem.stride_a_,
            gemm_problem.stride_b_,
-            {}, // stride_Ds
            gemm_problem.stride_c_,
        };
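Taken together, the hunks above complete the API split on the host side: plain GEMM callers (the pipeline tests and the tile-engine profiler) fill a `GemmHostArgs` without the `{} // ds_ptr` and `{} // stride_Ds` placeholders, grouped GEMM switches to `GroupedGemmHostArgs` plus `UniversalGemmKernelArgs<>` with per-operand arrays, and only the multi-D path keeps D-tensor slots. A rough sketch of the two host-side argument shapes follows; the member lists mirror the positional initializers in the diffs but are illustrative, not the exact ck_tile definitions.

```cpp
// Rough shape of the host-args split performed by this patch; names follow
// the diffs above, member lists are illustrative only.
#include <array>
#include <cstddef>
#include <cstdint>

using index_t = std::int32_t;

// Plain GEMM, E = A * B: no Ds slots, so callers no longer pass empty {}
// placeholders for ds_ptr / stride_Ds.
struct GemmHostArgsSketch
{
    const void* a_ptr;
    const void* b_ptr;
    void* e_ptr;
    index_t k_batch, M, N, K;
    index_t stride_A, stride_B, stride_E;
};

// Multi-D GEMM, E = cde_op(A * B, D0, D1, ...): only this variant carries
// the extra D tensors and their strides.
template <std::size_t NumDTensor = 2>
struct GemmMultiDHostArgsSketch
{
    const void* a_ptr;
    const void* b_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    void* e_ptr;
    index_t k_batch, M, N, K;
    index_t stride_A, stride_B;
    std::array<index_t, NumDTensor> stride_Ds;
    index_t stride_E;
};

int main()
{
    GemmHostArgsSketch plain{nullptr, nullptr, nullptr, 1, 64, 64, 64, 64, 64, 64};
    GemmMultiDHostArgsSketch<> multi{nullptr, nullptr, {nullptr, nullptr}, nullptr,
                                     1, 64, 64, 64, 64, 64, {64, 64}, 64};
    return plain.M == multi.M ? 0 : 1;
}
```

Keeping the D-tensor members out of the plain struct means the common case pays no syntactic or storage cost for an epilogue feature it never uses, which is the point of splitting the basic GEMM class into specialized ones.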