From a5a7f2675f0a5ee2977af37875f0befefa495f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 28 Dec 2024 14:40:17 +0100 Subject: [PATCH] [CK TILE] GEMM and Batched GEMM SplitK support (#1724) * [CK TILE] Add split K support in GEMM * Updates * Fixes * rebase * fix * Fix * fixes * support for batched gemm [ROCm/composable_kernel commit: af66494880fc6256e5e1ced779b6d80446726970] --- example/ck_tile/03_gemm/gemm_basic.hpp | 6 +- example/ck_tile/03_gemm/run_gemm_example.inc | 8 +- example/ck_tile/03_gemm/universal_gemm.cpp | 18 +- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 13 +- .../ck_tile/16_batched_gemm/batched_gemm.hpp | 3 +- .../run_batched_gemm_example.inc | 4 + .../ops/epilogue/cshuffle_epilogue.hpp | 31 +++- .../ops/epilogue/default_2d_epilogue.hpp | 26 ++- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 32 +++- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 164 +++++++++++++----- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 2 + .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 2 + .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 2 + ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 14 +- .../gemm_pipeline_agmem_bgmem_creg_v2.hpp | 2 + ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 2 + .../batched_gemm/test_batched_gemm_util.hpp | 3 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 4 +- 18 files changed, 245 insertions(+), 91 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 58cdaea7d8..38c0a279db 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -54,8 +54,7 @@ using CDataType = Types::CDataType; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("b", "1", "batch size") - .insert("m", "3840", "m dimension") + arg_parser.insert("m", "3840", "m dimension") .insert("n", "4096", "n dimension") .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") 
@@ -68,7 +67,8 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 68df389bfc..56d0348bd6 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -64,9 +64,9 @@ int run_gemm_example_with_layouts(int argc, ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - ck_tile::index_t batch_size = arg_parser.get_int("b"); - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); using namespace ck_tile::literals; @@ -133,7 +133,7 @@ int run_gemm_example_with_layouts(int argc, stride_A, stride_B, stride_C, - batch_size, + kbatch, n_warmup, n_repeat); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 6c87ca0087..1a9e025a9b 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -22,7 +22,7 @@ #endif template -float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) +float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) // Memory friendly for Interwave scheduler @@ -78,7 +78,9 @@ float gemm_calc(const 
gemm_basic_args& args, const ck_tile::stream_config& s) #endif ck_tile::GemmPipelineProblem>; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); @@ -106,17 +108,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) has_hot_loop_v, tail_number_v>>; using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKargs(args.p_a, - args.p_b, - args.p_c, - args.M, - args.N, - args.K, - args.stride_A, - args.stride_B, - args.stride_C); + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 9b4ed9a9e7..b9c9eaa583 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -70,20 +70,25 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using CodegenGemmTraits = ck_tile::TileGemmTraits; - using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; + using CodegenGemmPipeline = + ck_tile::GemmPipelineAGmemBGmemCRegV1; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. 
using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); constexpr dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + if(s.log_level_ > 0) { std::cout << "Launching kernel with args:" diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index f0c0c9efba..62f0058fd1 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -49,7 +49,8 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 4e7218b5b1..c14bb5668c 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -17,6 +17,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::index_t batch_stride_B, ck_tile::index_t batch_stride_C, ck_tile::index_t batch_count, + ck_tile::index_t kbatch, int n_warmup, int n_repeat) { @@ -24,6 +25,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.c_ptr = 
c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; @@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc, ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b"); ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c"); ck_tile::index_t batch_count = arg_parser.get_int("batch_count"); + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); @@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc, batch_stride_B, batch_stride_C, batch_count, + kbatch, n_warmup, n_repeat); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 9625b137bd..01105d2a82 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -56,6 +56,13 @@ struct CShuffleEpilogue // No additional shared memory needed CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() + { + // TODO: Currently CShuffle does not support vectorized stores after the permute. + // Once that is fixed, this function should return true. 
+ return false; + } + template CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile) { @@ -111,7 +118,9 @@ struct CShuffleEpilogue } } - template + template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile) { const auto& current_window_origin = o_dram_window_tmp.get_window_origin(); @@ -158,12 +167,26 @@ struct CShuffleEpilogue // Store the tile data to the permuted location if constexpr(kPadM || kPadN) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + if constexpr(out_memory_data_op == memory_operation_enum::set) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + else + { + update_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + } buffer_store_fence(); } else { - store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + if constexpr(out_memory_data_op == memory_operation_enum::set) + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + else + { + update_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } } } }; diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 7c5d5a6f31..177573de34 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -35,21 +35,39 @@ struct Default2DEpilogue CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; } + // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? 
- template + template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile) { // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + if constexpr(out_memory_data_op == memory_operation_enum::set) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + else + { + update_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + } buffer_store_fence(); } else { - store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + if constexpr(out_memory_data_op == memory_operation_enum::set) + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + else + { + update_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } } } }; diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 07a4cf8fbe..eaf66237af 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -67,9 +67,10 @@ struct BatchedGemmKernel : public GemmKernel(kargs.a_ptr) + batch_offset_A; + const ADataType* a_ptr = static_cast(kargs.a_ptr) + batch_offset_A + + splitk_batch_offset.a_k_split_offset; const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B); const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B); - const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B; + const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B + + splitk_batch_offset.b_k_split_offset; const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C); const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C); CDataType* c_ptr = static_cast(kargs.c_ptr) + batch_offset_C; - this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n); + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + if(kargs.KBatch == 1) + { + this->RunGemm(a_ptr, b_ptr, 
c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } + else + { + this->template RunGemm( + a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 925648a886..c81a64f7ad 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -93,6 +93,7 @@ struct GemmKernel index_t stride_A; index_t stride_B; index_t stride_C; + index_t KBatch; }; CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) @@ -105,28 +106,72 @@ struct GemmKernel hostArgs.K, hostArgs.stride_A, hostArgs.stride_B, - hostArgs.stride_C}; + hostArgs.stride_C, + hostArgs.k_batch}; } - // CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr, - // const void* b_ptr, - // void* c_ptr, - // index_t M, - // index_t N, - // index_t K, - // index_t stride_A, - // index_t stride_B, - // index_t stride_C) - // { - // return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C}; - // } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const GemmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = kargs.KBatch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + + if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead; + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead * kargs.stride_A; + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead * kargs.stride_B; + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead; + } + + if(k_id < static_cast(kargs.KBatch - 1)) + { + splitted_k = KRead; 
+ } + else + { + splitted_k = kargs.K - KRead * (kargs.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs) { + constexpr bool is_output_c_reg_transposed = + EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC(); + if constexpr(!((GemmPipeline::VectorSizeC % 2 == 0 && + std::is_same_v && + is_output_c_reg_transposed) || + !(std::is_same_v || std::is_same_v))) + { + if(kargs.KBatch != 1) + { + return false; + } + } + if constexpr(std::is_same_v) { if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false) @@ -198,17 +243,19 @@ struct GemmKernel return true; } - CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - const GemmKernelArgs& kargs) const + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + const GemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) { const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, kargs.K), + make_tuple(kargs.M, splitk_batch_offset.splitted_k), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); @@ -217,7 +264,7 @@ struct GemmKernel { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, kargs.K), + make_tuple(kargs.M, splitk_batch_offset.splitted_k), make_tuple(1, kargs.stride_A), number<1>{}, number<1>{}); @@ -229,7 +276,7 @@ struct GemmKernel { return make_naive_tensor_view( b_ptr, - make_tuple(kargs.N, kargs.K), + make_tuple(kargs.N, splitk_batch_offset.splitted_k), make_tuple(1, kargs.stride_B), number<1>{}, number<1>{}); @@ -238,7 +285,7 @@ struct GemmKernel { return make_naive_tensor_view( b_ptr, - make_tuple(kargs.N, kargs.K), + make_tuple(kargs.N, splitk_batch_offset.splitted_k), make_tuple(kargs.stride_B, 1), number{}, number<1>{}); @@ 
-248,7 +295,7 @@ struct GemmKernel const auto& c_tensor_view = [&]() { if constexpr(std::is_same_v) { - return make_naive_tensor_view( + return make_naive_tensor_view( c_ptr, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1), @@ -257,7 +304,7 @@ struct GemmKernel } else { - return make_naive_tensor_view( + return make_naive_tensor_view( c_ptr, make_tuple(kargs.M, kargs.N), make_tuple(1, kargs.stride_C), @@ -270,7 +317,7 @@ struct GemmKernel } template - CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) { const auto& a_pad_view = [&]() { const auto& a_tensor_view = views.at(I0); @@ -330,8 +377,8 @@ struct GemmKernel } template - CK_TILE_DEVICE auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) { const auto& a_pad_view = views.at(I0); const auto& a_block_window = make_tile_window( @@ -363,23 +410,27 @@ struct GemmKernel * @param kargs GEMM kernel arguments * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + * @tparam DstInMemOp Destination memory operation (default: set). 
*/ - CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr, - const BDataType* b_ptr, - CDataType* c_ptr, - const GemmKernelArgs& kargs, - const index_t block_idx_m, - const index_t block_idx_n) const + template + CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, + const BDataType* b_ptr, + CDataType* c_ptr, + void* smem_ptr, + const GemmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) { // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + ; + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; - - const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K); + const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); // Run GEMM cooperatively by whole workgroup. 
const auto& a_block_window = gemm_tile_windows.at(I0); @@ -389,18 +440,43 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{}(c_block_window, c_block_tile); + + constexpr bool is_output_c_reg_transposed = + EpiloguePipeline::IsOutputTransposed() != GemmPipeline::IsTransposeC(); + if constexpr((DstInMemOp == memory_operation_enum::set) || (sizeof(CDataType) > 2) || + (GemmPipeline::VectorSizeC % 2 == 0 && + std::is_same_v && + is_output_c_reg_transposed)) + { + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile); + } } CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const { const auto [i_m, i_n] = TilePartitioner{}(); + const SplitKBatchOffset splitk_batch_offset(kargs); // options - const ADataType* a_ptr = static_cast(kargs.a_ptr); - const BDataType* b_ptr = static_cast(kargs.b_ptr); - CDataType* c_ptr = static_cast(kargs.c_ptr); + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_ptr = + static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); - RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n); + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + if(kargs.KBatch == 1) + { + RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } + else + { + RunGemm( + a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index a72728b4a0..40628b1868 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -82,6 +82,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 return Policy::template GetSmemSize(); } + CK_TILE_HOST_DEVICE static 
constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } + template struct PipelineImpl : public PipelineImplBase { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index e2e94cf92b..c7a74c81e0 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -132,6 +132,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return Policy::template GetSmemSize(); } + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } + template struct PipelineImpl : public PipelineImplBase { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index 822748c69b..11a18e52c2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -53,6 +53,8 @@ struct GemmPipelineAGmemBGmemCRegV1 return Policy::template GetSmemSize(); } + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } + template @@ -114,8 +116,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { constexpr index_t smem_size_a = GetSmemSizeA(); constexpr index_t smem_size_b = GetSmemSizeB(); - index_t smem_size = 0; - smem_size += smem_size_a + smem_size_b; + constexpr index_t smem_size = smem_size_a + smem_size_b; return smem_size; } @@ -485,13 +486,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy } } + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { - constexpr bool TransposeC = false; - constexpr auto I0 = number<0>{}; - constexpr auto I1 = number<1>{}; - constexpr auto I2 = number<2>{}; + constexpr auto I0 = number<0>{}; + 
constexpr auto I1 = number<1>{}; + constexpr auto I2 = number<2>{}; using AccDataType = float; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 96a5a61c8b..07d4dc441e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV2 Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); } + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } + template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index d3f3077870..e7e9b3d679 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -93,7 +93,7 @@ class TestCkTileBatchedGemm : public ::testing::Test auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); constexpr dim3 blocks = Kernel::BlockSize(); if(s.log_level_ > 0) @@ -186,6 +186,7 @@ class TestCkTileBatchedGemm : public ::testing::Test args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + args.k_batch = 1; args.M = M; args.N = N; args.K = K; diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 53ead4d8d6..4b0e40060d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -74,7 +74,9 @@ class TestCkTileGemmPipeline : public 
::testing::Test ck_tile:: GemmPipelineProblem>>; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);