From f1262b783a2d642dd31aecab13738e19cd838450 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 20 Mar 2025 11:17:04 +0100 Subject: [PATCH] [CK_TILE] Switch to universal gemm for batched and grouped gemms (#1919) * switch to universal gemm for batched and grouped gemms * added reviewer comments * fixed grouped gemm tests [ROCm/composable_kernel commit: 0e91d32c61cbb5c093bf947ff4e13b229a652e34] --- .../ck_tile/16_batched_gemm/batched_gemm.cpp | 297 ++++++++++++--- .../ck_tile/16_batched_gemm/batched_gemm.hpp | 40 +- .../run_batched_gemm_example.inc | 1 - .../ck_tile/17_grouped_gemm/grouped_gemm.cpp | 354 +++++++++++++----- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 28 +- .../run_grouped_gemm_example.inc | 9 +- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 4 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 28 +- .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 45 +-- .../test_batched_gemm_ut_cases.inc | 4 +- .../batched_gemm/test_batched_gemm_util.hpp | 178 ++++++--- .../test_grouped_gemm_ut_cases.inc | 6 +- .../grouped_gemm/test_grouped_gemm_util.hpp | 218 +++++++---- 13 files changed, 853 insertions(+), 359 deletions(-) diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 286fe4201d..a0cd18ec74 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -18,16 +18,42 @@ template float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { - // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - // This part comes from the Codegen +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; + + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; constexpr ck_tile::index_t K_Tile = 32; constexpr ck_tile::index_t M_Warp = 2; @@ -36,61 +62,232 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + constexpr ck_tile::index_t K_Warp_Tile = 16; - using CodegenGemmShape = + constexpr bool DoubleSmemBuffer = true; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; - using CodegenGemmTraits = - ck_tile::TileGemmTraits; - using CodegenPipelineProblem = ck_tile:: - GemmPipelineProblem; - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::BatchedGemmKernel; + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; - auto kargs = Kernel::MakeKernelArgs(args); + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + float ave_time{0}; - if(!Kernel::IsSupportedArgument(kargs)) + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got " + << tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } - if(s.log_level_ > 0) + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +#endif + } + else { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + std::ostringstream err; + err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " + "got " + << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 7b7e22160a..0999c7ad3b 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,6 +9,30 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + template struct BatchedGemmTypeConfig; @@ -32,19 +56,19 @@ using CDataType = Types::CDataType; auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "256", "m dimension") - .insert("n", "128", "n dimension") - .insert("k", "128", "k dimension") + arg_parser.insert("m", "512", "m dimension") + .insert("n", "1024", "n dimension") + .insert("k", "2048", "k dimension") .insert("stride_a", "0", "Tensor A stride") .insert("stride_b", "0", "Tensor B stride") .insert("stride_c", "0", "Tensor C stride") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("batch_stride_a", "32768", "Batch A stride") - .insert("batch_stride_b", "16384", "Batch B stride") - .insert("batch_stride_c", "32768", "Batch C stride") - .insert("batch_count", "16", "Batch count") + .insert("batch_stride_a", "1048576", "Batch A stride") + .insert("batch_stride_b", "2097152", "Batch B stride") + .insert("batch_stride_c", "524288", "Batch C stride") + .insert("batch_count", "8", "Batch count") .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("warmup", "50", "number of iterations before benchmark the kernel") diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 1105304e3e..16a31e519a 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -185,7 +185,6 @@ int run_batched_gemm_example_with_layouts(int argc, kbatch, n_warmup, n_repeat); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 03d5818179..2a9903362d 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -16,85 +16,9 @@ #include "ck_tile/host.hpp" #include "grouped_gemm.hpp" -namespace { - -struct GroupedGemmKernelParam -{ - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = false; - - static const int kBlockPerCu = 1; - static const ck_tile::index_t M_Tile = 128; - static const ck_tile::index_t N_Tile = 128; - static const ck_tile::index_t K_Tile = 32; - - static const ck_tile::index_t M_Warp = 2; - static const ck_tile::index_t N_Warp = 2; - static const ck_tile::index_t K_Warp = 1; - - static const ck_tile::index_t M_Warp_Tile = 32; - static const ck_tile::index_t N_Warp_Tile = 32; - static const ck_tile::index_t K_Warp_Tile = 8; -}; - -using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - -using TilePartitioner = ck_tile::GemmTile1DPartitioner; - -template -using CodegenGemmTraits = ck_tile::TileGemmTraits; - -template -using CodegenPipelineProblem = - ck_tile::GemmPipelineProblem>; - -template -using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1>; - -template -using GemmEpilogue = ck_tile::CShuffleEpilogue::kBlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GroupedGemmKernelParam::M_Warp, - GroupedGemmKernelParam::N_Warp, - GroupedGemmKernelParam::M_Warp_Tile, - GroupedGemmKernelParam::N_Warp_Tile, - GroupedGemmKernelParam::K_Warp_Tile, - CodegenPipelineProblem::TransposeC>>; - -template -using Kernel = ck_tile::GroupedGemmKernel, - GemmEpilogue>; -}; // namespace - std::size_t get_workspace_size(const std::vector& gemm_descs) { - return ::Kernel::GetWorkSpaceSize(gemm_descs); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } template @@ -102,37 +26,265 @@ float grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, void* p_workspace_) { - using GroupedGemmKernel = ::Kernel; +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Memory friendly for Interwave scheduler + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 32; + constexpr ck_tile::index_t K_Tile = 64; - auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs); + constexpr ck_tile::index_t M_Warp = 4; + constexpr ck_tile::index_t N_Warp = 1; + constexpr ck_tile::index_t K_Warp = 1; - const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs); - constexpr dim3 blocks = GroupedGemmKernel::BlockSize(); + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 8; - ck_tile::hip_check_error(hipMemcpyWithStream( - p_workspace_, - arguments.data(), - arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg), - hipMemcpyHostToDevice, - s.stream_id_)); + constexpr bool DoubleSmemBuffer = false; +#endif +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + // Compute friendly for Intrawave scheduler + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; - if(s.log_level_ > 0) + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = true; +#endif + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + + const dim3 grids = Kernel::GridSize(gemm_descs); + constexpr dim3 blocks = Kernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(p_workspace_), + gemm_descs.size())); + return ave_time; + }; + + if(has_hot_loop) { - std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got " + << tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +#endif + } + else + { + std::ostringstream err; + err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " + << "got " << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - float ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - GroupedGemmKernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); return ave_time; } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 14d450034d..4fec329c2f 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -9,6 +9,30 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + template struct GemmTypeConfig; @@ -29,7 +53,7 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; -using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; +using grouped_gemm_kargs = ck_tile::GemmHostArgs; auto create_args(int argc, char* argv[]) { @@ -46,7 +70,7 @@ auto create_args(int argc, char* argv[]) .insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("warmup", "10", "number of iterations before benchmark the kernel.") .insert("repeat", "100", "number of iterations to benchmark the kernel.") - .insert("group_count", "16", "group count."); + .insert("group_count", "8", "group count."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 080ea818c9..f068510d26 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -101,8 +101,8 @@ int run_grouped_gemm_example_with_layouts(int argc, for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); - Ns.push_back(128 + 128 * i); - Ks.push_back(128 + 64 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(256 + 64 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); @@ -169,7 +169,10 @@ int run_grouped_gemm_example_with_layouts(int argc, const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); - gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + // TODO Add support for kbatch > 1 in grouped gemm + static constexpr ck_tile::index_t k_batch = 1; + gemm_descs.push_back( + {p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } invoke_gemm(warmup, repeat, group_count, gemm_descs); diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 323c682f2c..dfb6bfae58 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -46,7 +46,7 @@ struct BatchedGemmKernel : public GemmKernel; - using GemmKernelArgs = typename Base::GemmKernelArgs; + using GemmKernelArgs = typename ck_tile::GemmKernelArgs; using ADataType = typename Base::ADataType; using BDataType = typename Base::BDataType; @@ -65,7 +65,7 @@ struct BatchedGemmKernel : public GemmKernel, - concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); // clang-format on diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 503a92b863..9435855d0a 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -56,6 +56,20 @@ struct GemmHostArgs : public GemmProblem index_t k_batch; }; +struct GemmKernelArgs +{ + const void* a_ptr; + const void* b_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t k_batch; +}; + template struct GemmKernel { @@ -90,20 +104,6 @@ struct GemmKernel CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - struct GemmKernelArgs - { - const void* a_ptr; - const void* b_ptr; - void* c_ptr; - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; - index_t k_batch; - }; - CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs) { return GemmKernelArgs{hostArgs.a_ptr, diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 751e7c0e1a..5577cb083a 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -11,24 +11,17 @@ namespace ck_tile { -struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs +struct GemmTransKernelArg { - CK_TILE_HOST GroupedGemmHostArgs() noexcept = default; - CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_, - const void* b_ptr_, - void* c_ptr_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_) - : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_) + GemmKernelArgs group_karg; + ck_tile::index_t block_start; + ck_tile::index_t block_end; + + GemmTransKernelArg() = default; + GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end) + : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } - - private: - static constexpr index_t KBatch = 1; }; template @@ -47,36 +40,22 @@ struct GroupedGemmKernel : public GemmKernel; using Base = GemmKernel; - using GemmKernelArgs = typename Base::GemmKernelArgs; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - struct GemmTransKernelArg - { - GemmKernelArgs group_karg; - ck_tile::index_t block_start; - ck_tile::index_t block_end; - - GemmTransKernelArg() = default; - GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end) - : group_karg{karg}, block_start{bl_start}, block_end{bl_end} - { - } - }; - [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off using P_ = GemmPipeline; return concat('_', "gemm_grouped", gemm_prec_str, - concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); // clang-format on } - __host__ static auto GetWorkSpaceSize(const std::vector& gemm_descs) + __host__ static auto GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); @@ -84,7 +63,7 @@ struct GroupedGemmKernel : public GemmKernel dim3 { return dim3(KernelBlockSize); } - __host__ static constexpr auto GridSize(const std::vector& gemm_descs) + __host__ static constexpr auto GridSize(const std::vector& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -95,7 +74,7 @@ struct GroupedGemmKernel : public GemmKernel& gemm_descs) + CK_TILE_HOST static auto MakeKargs(const std::vector& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc index f261164d61..74338ba383 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc +++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc @@ -3,7 +3,7 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic) { constexpr int M = 256; - constexpr int N = 128; - constexpr int K = 128; + constexpr int N = 256; + constexpr int K = 512; this->Run(M, N, K); } diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 0f787b718d..0af3ef3b34 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -28,17 +28,9 @@ class TestCkTileBatchedGemm : public ::testing::Test void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s) { - // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - // This part comes from the Codegen - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -46,72 +38,144 @@ class TestCkTileBatchedGemm : public ::testing::Test constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + constexpr ck_tile::index_t K_Warp_Tile = 16; - using CodegenGemmShape = + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; - using CodegenGemmTraits = - ck_tile::TileGemmTraits; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; - using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + float ave_time{0}; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = - ck_tile::BatchedGemmKernel; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - auto kargs = Kernel::MakeKernelArgs(args); + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - if(s.log_level_ > 0) + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" + << tail_num << "\" which is not supported! PrefetchStages: " + << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: void Run(const int M, const int N, const int K, - int StrideA = 128, - int StrideB = 128, - int StrideC = 128, - const int BatchStrideA = 32768, - const int BatchStrideB = 16384, - const int BatchStrideC = 32768, - const int BatchCount = 16) + int StrideA = 512, + int StrideB = 512, + int StrideC = 256, + const int BatchStrideA = 131072, + const int BatchStrideB = 131072, + const int BatchStrideC = 65536, + const int BatchCount = 8) { using namespace ck_tile::literals; diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc index 68c4693bb3..9f6b66c92b 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_ut_cases.inc @@ -2,7 +2,7 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic) { - const int group_count = 16; + const int group_count = 8; std::vector Ms; std::vector Ns; std::vector Ks; @@ -13,8 +13,8 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic) for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); - Ns.push_back(128 + 128 * i); - Ks.push_back(128 + 64 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(256 + 64 * i); stride_As.push_back(Ks[i]); stride_Bs.push_back(Ks[i]); diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index cd94d0b867..b125d19762 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -44,65 +44,10 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 8; }; - using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - template - using CodegenGemmTraits = ck_tile::TileGemmTraits; - - template - using CodegenPipelineProblem = - ck_tile::GemmPipelineProblem>; - - template - using CodegenGemmPipeline = - ck_tile::GemmPipelineAGmemBGmemCRegV1>; - - template - using GemmEpilogue = ck_tile::CShuffleEpilogue::BlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GroupedGemKernelParam::M_Warp, - GroupedGemKernelParam::N_Warp, - GroupedGemKernelParam::M_Warp_Tile, - GroupedGemKernelParam::N_Warp_Tile, - GroupedGemKernelParam::K_Warp_Tile, - CodegenPipelineProblem::TransposeC>>; - - template - using Kernel = ck_tile::GroupedGemmKernel, - GemmEpilogue>; - - using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; - std::size_t GetWorkspaceSize(const std::vector& gemm_descs) + using grouped_gemm_kargs = ck_tile::GemmHostArgs; + std::size_t get_workspace_size(const std::vector& gemm_descs) { - return Kernel::GetWorkSpaceSize(gemm_descs); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } template @@ -110,35 +55,140 @@ class TestCkTileGroupedGemm : public ::testing::Test const ck_tile::stream_config& s, void* p_workspace_) { - using GroupedGemmKernel = Kernel; + constexpr bool DoubleSmemBuffer = false; + constexpr bool TransposeC = false; - auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs); + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; - const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs); - constexpr dim3 blocks = GroupedGemmKernel::BlockSize(); + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; - ck_tile::hip_check_error(hipMemcpyWithStream( - p_workspace_, - arguments.data(), - arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg), - hipMemcpyHostToDevice, - s.stream_id_)); + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; - if(s.log_level_ > 0) + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t K_split = + (gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + + const dim3 grids = Kernel::GridSize(gemm_descs); + constexpr dim3 blocks = Kernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(p_workspace_), + gemm_descs.size())); + return ave_time; + }; + + if(has_hot_loop) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" + << tail_num << "\" which is not supported! PrefetchStages: " + << BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - GroupedGemmKernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); } public: @@ -243,12 +293,14 @@ class TestCkTileGroupedGemm : public ::testing::Test const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + // TODO add support for kbatch > 1 + static constexpr ck_tile::index_t k_batch = 1; gemm_descs.push_back( - {p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + {p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } ck_tile::DeviceMem gemm_workspace; - gemm_workspace.Realloc(GetWorkspaceSize(gemm_descs)); + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); invoke_grouped_gemm( gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());