mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4834 (commit e75e6cb)
[CK_TILE][GEMM] Fix eightwarp error & Add eightwarp unit test (#4834) ## Motivation The primary goal of this PR is to fix a critical issue in the EightWarps implementation within ck_tile. Additionally, unit tests were added to ensure that CI can detect errors. ## Test Plan ninja test_tile_gemm_quant_abquant_eightwarps ./bin/test_tile_gemm_quant_abquant_eightwarps ## Test Result All EightWarps related test cases in TestCkTileGemmABQuant completed successfully without linker errors or validation mismatches. ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b09112bbad
commit
30702c9cbc
@@ -81,6 +81,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_eightwarps
|
||||
test_gemm_quant_abquant_eightwarps.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_eightwarps PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# ABQuant split-K tests
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_gemm_quant_abquant_splitk_decode.cpp
|
||||
@@ -275,6 +280,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_abquant_a4w4_base
|
||||
test_tile_gemm_quant_abquant_a4w4_padding
|
||||
test_tile_gemm_quant_abquant_a4w4_preshuffle
|
||||
test_tile_gemm_quant_abquant_eightwarps
|
||||
# ABQuant split-K tests
|
||||
test_tile_gemm_quant_abquant_splitk_decode
|
||||
test_tile_gemm_quant_abquant_splitk_prefill
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using ABQuantGrouped =
|
||||
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
// Type combinations for ABQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantEightWarpsTypes = ::testing::Types<
|
||||
// PreshuffleQuant = false && TransposeC = true (RCR layout with ColumnMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWarps, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWarps_PreshuffleB, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWarpsTypes);
|
||||
|
||||
// ABQuant tests
|
||||
// Typed test instantiated for each tuple registered with the TestCkTileGemmABQuant suite.
// It calls the fixture's run_test_with_validation once with the sizes 1024, 1024, 1024.
// NOTE(review): the three arguments are presumably the GEMM M, N and K extents -- confirm
// against the fixture declaration in test_gemm_quant_fixtures.hpp.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
#endif
|
||||
@@ -29,10 +29,11 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
|
||||
// clang-format off
|
||||
using ABQuantPreshuffleBTypes = ::testing::Types<
|
||||
// 1D B-scales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize, ColumnMajor>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize, ColumnMajor>,
|
||||
/// 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ)
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefillTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefillTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize2D128N, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleB_ABQuant_Prefill, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -188,6 +188,33 @@ struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleB
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
};
|
||||
|
||||
// Variant of GemmConfigPreshuffleBPrefill that declares TransposeC = true; every other
// setting comes from the base config.
struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleBPrefill
|
||||
{
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
// GEMM test configuration whose warp layout multiplies out to eight warps
// (M_Warp * N_Warp * K_Warp = 4 * 2 * 1). Settings not declared here come from GemmConfigBase.
// TestCkTileGemmABQuant::run_quant_gemm_impl checks for a warp product of 8 and
// K_Warp_Tile == 128 (among other conditions) when computing its eight_warps flag.
struct GemmConfigEightWarps : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
// Block tile extents: N_Tile and K_Tile scale with the warp counts above,
// giving N_Tile = 256 and K_Tile = 128 for N_Warp = 2 and K_Warp = 1.
static constexpr ck_tile::index_t M_Tile = 192;
|
||||
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 * K_Warp;
|
||||
|
||||
// NOTE(review): M_Warp_Tile is not declared in this struct; presumably it is inherited
// from GemmConfigBase -- confirm. The meaning of the trailing `true` template argument
// is defined by ck_tile::get_k_warp_tile and is not visible here.
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<ck_tile::fp8_t, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
// GemmConfigEightWarps with PreshuffleB declared true; all tile and warp settings are
// inherited unchanged from GemmConfigEightWarps.
struct GemmConfigEightWarps_PreshuffleB : public GemmConfigEightWarps
|
||||
{
|
||||
static constexpr bool PreshuffleB = true;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
|
||||
{
|
||||
@@ -1186,6 +1213,22 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr bool IS_FP8BLOCKSCALE = BQuantGroupSize::kN == 128 &&
|
||||
(std::is_same_v<ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<ADataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<BDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::bf8_t>);
|
||||
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
|
||||
constexpr bool eight_warps =
|
||||
#ifdef CK_GFX950_SUPPORT
|
||||
IS_FP8BLOCKSCALE &&
|
||||
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
|
||||
GemmConfig::K_Warp_Tile == 128;
|
||||
#else
|
||||
false;
|
||||
#endif
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -1193,20 +1236,27 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
PreshuffleB == true,
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
constexpr auto base_gemm_pipeline = []() {
|
||||
if constexpr(eight_warps)
|
||||
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
|
||||
else if constexpr(PreshuffleB)
|
||||
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
|
||||
else if constexpr(IS_FP8BLOCKSCALE)
|
||||
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
|
||||
else
|
||||
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
|
||||
}();
|
||||
using BaseGemmPipeline = std::decay_t<decltype(base_gemm_pipeline)>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::index_t K_split =
|
||||
ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::GemmABQuantPipelineProblem<ADataType,
|
||||
@@ -1224,10 +1274,12 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline =
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
using GemmPipeline = std::conditional_t<
|
||||
eight_warps,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrAsync<PipelineProblem>,
|
||||
std::conditional_t<PreshuffleB,
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
@@ -1264,9 +1316,9 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for ABQuant kernel");
|
||||
}
|
||||
|
||||
using k_attr_t = ck_tile::kernel_attr<eight_warps>;
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu, k_attr_t>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user