diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index f97cc03d2a..1a833df6c2 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,10 +1,9 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) - - -set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) - list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) \ No newline at end of file +target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 9975f2024b..606d98d9e2 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -356,6 +356,8 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_grouped_gemm_example(argc, argv); #else - return !run_grouped_gemm_example(argc, argv); + return !run_grouped_gemm_example(argc, argv) || + !run_grouped_gemm_example(argc, argv) || + !run_grouped_gemm_example(argc, argv); #endif } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 1ae0844032..6493a542ba 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -91,7 +91,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; - static constexpr bool Persistent = false; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = false; }; @@ -139,6 +139,29 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +template +struct GemmConfigComputeV4_V2 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + + static constexpr int kBlockPerCu = 2; +}; + template struct GemmConfigPreshuffleDecode : public GemmConfigBase { diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index dda38bbc47..e38e49f5d1 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -292,34 +292,8 @@ struct GroupedGemmKernel { __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle) - { - - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - return; - } - else - { - - Base::RunGemm2LDS({a_ptr}, - {b_ptr}, - {/*ds_ptr*/}, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + RunGemmWithPipelineSelection2LDS( + a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n); } else // SingleSmemBuffer {