diff --git a/CHANGELOG.md b/CHANGELOG.md index c914224bb3..e1feebcb3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added overload of load_tile_transpose that takes reference to output tensor as output parameter * Use data type from LDS tensor view when determining tile distribution for transpose in the GEMM pipeline +* Added eightwarps support for abquant mode in blockscale GEMM. * Added preshuffleB support for abquant mode in blockscale GEMM. * Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight. * Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32". diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 7ebfa412f7..33a31b6557 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -116,7 +116,7 @@ struct CShuffleEpilogue static constexpr index_t isCTransposed = Problem::isCTransposed; static constexpr bool FixedVectorSize = Problem::FixedVectorSize; static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; -#ifdef __gfx95__ +#if defined(CK_GFX950_SUPPORT) static constexpr bool EightWave = (MWave * NWave == 8); #else static constexpr bool EightWave = false; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 486c1836ea..5e7fb0e4da 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1328,7 +1328,8 @@ struct QuantGemmKernel // For RowMajor C, M is the row dimension — check M alignment here because // ALayout=RowMajor does not check M (it only checks K), leaving a gap for // the RowMajorA + RowMajorC combination. - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false && + GemmPipeline::BlockGemmShape::NumWarps != 8) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 2dcf0bca0c..14748a9d1b 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -81,6 +81,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_eightwarps + test_gemm_quant_abquant_eightwarps.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_eightwarps PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # ABQuant split-K tests add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode test_gemm_quant_abquant_splitk_decode.cpp @@ -275,6 +280,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") test_tile_gemm_quant_abquant_a4w4_base test_tile_gemm_quant_abquant_a4w4_padding test_tile_gemm_quant_abquant_a4w4_preshuffle + test_tile_gemm_quant_abquant_eightwarps # ABQuant split-K tests test_tile_gemm_quant_abquant_splitk_decode test_tile_gemm_quant_abquant_splitk_prefill diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwarps.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwarps.cpp new file mode 100644 index 0000000000..03b7cf8b03 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_eightwarps.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; +#ifdef CK_GFX950_SUPPORT +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantEightWarpsTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWarpsTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} +#endif diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp index a317a413ce..7d8b62616e 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp @@ -29,10 +29,11 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // clang-format off using ABQuantPreshuffleBTypes = ::testing::Types< // 1D B-scales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - + std::tuple, /// 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ) - std::tuple + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 1266fa5889..8fc92d588a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -188,6 +188,33 @@ struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleB static constexpr bool TransposeC = TransposeC_; }; +struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleBPrefill +{ + static constexpr bool TransposeC = true; +}; + +struct GemmConfigEightWarps : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong! + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Tile = 192; + static constexpr ck_tile::index_t N_Tile = 128 * N_Warp; + static constexpr ck_tile::index_t K_Tile = 128 * K_Warp; + + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr bool kPadK = false; + static constexpr bool TransposeC = true; +}; + +struct GemmConfigEightWarps_PreshuffleB : public GemmConfigEightWarps +{ + static constexpr bool PreshuffleB = true; +}; + template class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase> { @@ -1186,6 +1213,22 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase); + constexpr bool IS_FP8BLOCKSCALE = BQuantGroupSize::kN == 128 && + (std::is_same_v || + std::is_same_v) && + (std::is_same_v || + std::is_same_v); + constexpr bool transpose_c = CodegenGemmTraits::TransposeC; + constexpr bool eight_warps = +#ifdef CK_GFX950_SUPPORT + IS_FP8BLOCKSCALE && + (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) && + GemmConfig::K_Warp_Tile == 128; +#else + false; +#endif using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - using BaseGemmPipeline = std::conditional_t< - PreshuffleB == true, - ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2, - ck_tile::BaseGemmPipelineAgBgCrCompV3>; + constexpr auto base_gemm_pipeline = []() { + if constexpr(eight_warps) + return ck_tile::BaseGemmPipelineAgBgCrCompV3{}; + else if constexpr(PreshuffleB) + return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2{}; + else if constexpr(IS_FP8BLOCKSCALE) + return ck_tile::BaseGemmPipelineAgBgCrCompV3{}; + else + return ck_tile::BaseGemmPipelineAgBgCrCompV3{}; + }(); + using BaseGemmPipeline = std::decay_t; - const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::index_t K_split = + ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile); + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - constexpr bool transpose_c = CodegenGemmTraits::TransposeC; using PipelineProblem = ck_tile::GemmABQuantPipelineProblem; - using GemmPipeline = - std::conditional_t, + std::conditional_t, - ck_tile::ABQuantGemmPipelineAgBgCrCompV3>; + ck_tile::ABQuantGemmPipelineAgBgCrCompV3>>; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem; ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel( Kernel{}, grids, blocks, 0, kargs)); };