[rocm-libraries] ROCm/rocm-libraries#4834 (commit e75e6cb)

[CK_TILE][GEMM] Fix eightwarp error & Add eightwarp unit test
 (#4834)

## Motivation

The primary goal of this PR is to fix a critical issue in the EightWarps
implementation within ck_tile. Additionally, unit tests were added to
ensure that CI can detect errors.

## Test Plan

ninja test_tile_gemm_quant_abquant_eightwarps
./bin/test_tile_gemm_quant_abquant_eightwarps

## Test Result

All EightWarps related test cases in TestCkTileGemmABQuant completed
successfully without linker errors or validation mismatches.

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
kensclin
2026-03-04 04:11:27 +00:00
committed by assistant-librarian[bot]
parent b09112bbad
commit 30702c9cbc
7 changed files with 124 additions and 18 deletions

View File

@@ -81,6 +81,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_abquant_eightwarps
test_gemm_quant_abquant_eightwarps.cpp
)
target_compile_options(test_tile_gemm_quant_abquant_eightwarps PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# ABQuant split-K tests
add_gtest_executable(test_tile_gemm_quant_abquant_splitk_decode
test_gemm_quant_abquant_splitk_decode.cpp
@@ -275,6 +280,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
test_tile_gemm_quant_abquant_a4w4_base
test_tile_gemm_quant_abquant_a4w4_padding
test_tile_gemm_quant_abquant_a4w4_preshuffle
test_tile_gemm_quant_abquant_eightwarps
# ABQuant split-K tests
test_tile_gemm_quant_abquant_splitk_decode
test_tile_gemm_quant_abquant_splitk_prefill

View File

@@ -0,0 +1,45 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
#ifdef CK_GFX950_SUPPORT
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantEightWarpsTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWarps, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, ColumnMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigEightWarps_PreshuffleB, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantEightWarpsTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
#endif

View File

@@ -29,10 +29,11 @@ using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>
// clang-format off
using ABQuantPreshuffleBTypes = ::testing::Types<
// 1D B-scales; PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize, ColumnMajor>,
/// 2D B-scales; PreshuffleQuant = false && TransposeC = true (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefillTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefillTransposeC, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPrefill, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleB_ABQuant_Prefill, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on

View File

@@ -188,6 +188,33 @@ struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleB
static constexpr bool TransposeC = TransposeC_;
};
struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleBPrefill
{
static constexpr bool TransposeC = true;
};
struct GemmConfigEightWarps : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Tile = 192;
static constexpr ck_tile::index_t N_Tile = 128 * N_Warp;
static constexpr ck_tile::index_t K_Tile = 128 * K_Warp;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<ck_tile::fp8_t, M_Warp_Tile, true>();
static constexpr bool kPadK = false;
static constexpr bool TransposeC = true;
};
struct GemmConfigEightWarps_PreshuffleB : public GemmConfigEightWarps
{
static constexpr bool PreshuffleB = true;
};
template <typename Tuple>
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
{
@@ -1186,6 +1213,22 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool IS_FP8BLOCKSCALE = BQuantGroupSize::kN == 128 &&
(std::is_same_v<ADataType, ck_tile::fp8_t> ||
std::is_same_v<ADataType, ck_tile::bf8_t>) &&
(std::is_same_v<BDataType, ck_tile::fp8_t> ||
std::is_same_v<BDataType, ck_tile::bf8_t>);
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
constexpr bool eight_warps =
#ifdef CK_GFX950_SUPPORT
IS_FP8BLOCKSCALE &&
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
GemmConfig::K_Warp_Tile == 128;
#else
false;
#endif
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
@@ -1193,20 +1236,27 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = std::conditional_t<
PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
constexpr auto base_gemm_pipeline = []() {
if constexpr(eight_warps)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else if constexpr(PreshuffleB)
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
else if constexpr(IS_FP8BLOCKSCALE)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
}();
using BaseGemmPipeline = std::decay_t<decltype(base_gemm_pipeline)>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::index_t K_split =
ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
using PipelineProblem =
ck_tile::GemmABQuantPipelineProblem<ADataType,
@@ -1224,10 +1274,12 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
std::conditional_t<PreshuffleB == true,
using GemmPipeline = std::conditional_t<
eight_warps,
ck_tile::ABQuantGemmPipelineAgBgCrAsync<PipelineProblem>,
std::conditional_t<PreshuffleB,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
@@ -1264,9 +1316,9 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
{
throw std::runtime_error("Arguments not supported for ABQuant kernel");
}
using k_attr_t = ck_tile::kernel_attr<eight_warps>;
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu, k_attr_t>(
Kernel{}, grids, blocks, 0, kargs));
};