Support A/B Quantization in Blockscale GEMM (#3343)

* Support A/B Quantization in Blockscale GEMM

* Support A/B Quantization in Blockscale GEMM

* Support A/B Quantization in Blockscale GEMM

* Support A/B Quantization in Blockscale GEMM

* Support A/B Quantization in Blockscale GEMM

* Implement review suggested changes

* Implement review suggested changes

* Sync with develop

* fix pre-commit error

* Add unit tests for blockscale AB-Quantization

* fix pre-commit error

* fix pre-commit error

* fix compile error

* fix compile error

* fix clang-format

* fix clang-format

* fix enumeration values not handled in switch

* rebase file

* Add missing enums to data_type_sizeof (#3430)

Fixes broken build on gfx942. This was some test code that got merged at the same time.

* [CK_BUILDER] CK Tile header installation for builder, algorithm concept improvements (#3419)

* Added install of CK_Tile headers when using CK_EXPERIMENTAL_BUILDER. MIOpen needs this since the builder uses features from CK Tile and the CK Tile install is excluded when doing a narrow build for MIOpen
* Changed algorithm concept type checks to be concepts instead of constexpr bool functions. This improves compiler error messages when using these concepts in static_asserts

---------

Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com>

* Add build trace diagnostics to CI. (#3432)

* generate and visualize build traces for all archs

* generate build traces in all cases

* fix jenkins logic

* fix typo

* use more threads for parsing dependency map

* add script to parse ninja traces and issue warnings

* fix python script syntax and header

* fix python syntax one more time

* fix python syntax

* Support A/B Quantization in Blockscale GEMM

* Implement review suggested changes

* Sync with develop

* Add unit tests for blockscale AB-Quantization

* fix enumeration values not handled in switch

* rebase file

* rebase file

---------

Co-authored-by: John Shumway <jshumway@amd.com>
Co-authored-by: DarylHawkinsAMD <Daryl.Hawkins@amd.com>
Co-authored-by: Daryl Hawkins <DarylHawkins@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
kensclin
2025-12-17 23:13:47 +08:00
committed by GitHub
parent 292df2719f
commit 0500fcc017
30 changed files with 2318 additions and 353 deletions

6
test/ck_tile/gemm_block_scale/CMakeLists.txt Executable file → Normal file
View File

@@ -25,6 +25,12 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_gtest_executable(test_tile_gemm_quant_aquant_base_ccr
test_gemm_quant_aquant_base_ccr.cpp
)
# ABQuant tests
add_gtest_executable(test_tile_gemm_quant_abquant
test_gemm_quant_abquant.cpp
)
target_compile_options(test_tile_gemm_quant_abquant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_prefill

View File

@@ -0,0 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// Typed-test instantiation for ABQuant (grouped quantization on BOTH the A and
// B operands) of the ck_tile block-scale GEMM. The fixture itself
// (TestCkTileGemmABQuant) lives in test_gemm_quant_fixtures.hpp.
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>

#include "test_gemm_quant_fixtures.hpp"

// Type aliases for readability
using RowMajor    = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8         = ck_tile::fp8_t;
using BF8         = ck_tile::bf8_t;
using Half        = ck_tile::half_t;
// NOTE(review): PkInt4 is not used by any tuple in ABQuantTypes below —
// either add pk_int4 coverage or drop the alias.
using PkInt4 = ck_tile::pk_int4_t;

// Compile-time tag selecting the ABQuantGrouped kernel variant.
using ABQuantGrouped =
    std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;

// 1D group shape: one scale per 128 elements along K (sequence is <M, N, K>).
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant: one scale per 128x128 (N x K) block of B.
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;

// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// Covers FP8 and BF8 inputs across the RC/CR/RR A/B layout combinations, each
// with a 1D (per-K-group) and a 2D (128x128 block) B-scale shape.
// clang-format off
using ABQuantTypes = ::testing::Types<
    // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
    std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
    std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
    std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
    std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
    std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
    std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
    std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
    std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
    std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on

// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);

// ABQuant tests: run one 1024x1024x1024 GEMM per type combination and
// validate the device output against the host reference.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
    this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -21,6 +21,24 @@
template <ck_tile::QuantType QT>
struct QuantTypeTraits;
// SafeTupleElement: yields std::tuple_element_t<Index, TTuple> when Index is a
// valid position in TTuple, and DefaultType otherwise. This lets test tuples
// omit trailing elements (e.g. BQuantGroupSize, BQLayout) and fall back to a
// sensible default without an out-of-range tuple_element hard error.
template <typename TTuple, size_t Index, typename DefaultType, typename = void>
struct SafeTupleElement
{
    // Fallback case: Index is past the end of TTuple (or the specialization
    // below did not match) — use the caller-supplied default.
    using type = DefaultType;
};

// In-range case: SFINAE selects this partial specialization whenever the
// tuple is large enough to contain position Index.
template <typename TTuple, size_t Index, typename DefaultType>
struct SafeTupleElement<TTuple,
                        Index,
                        DefaultType,
                        std::enable_if_t<(std::tuple_size_v<TTuple> > Index)>>
{
    using type = std::tuple_element_t<Index, TTuple>;
};

// Convenience alias mirroring the standard *_t naming convention.
template <typename TTuple, size_t Index, typename DefaultType>
using SafeTupleElement_t = typename SafeTupleElement<TTuple, Index, DefaultType>::type;
// Base class for common quant gemm functionality
template <typename Tuple, typename Derived>
class TestCkTileGemmQuantBase : public ::testing::Test
@@ -37,6 +55,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
static constexpr auto QuantType = std::tuple_element_t<8, Tuple>::value;
using GemmConfig = std::tuple_element_t<9, Tuple>;
using QuantGroupSize = std::tuple_element_t<10, Tuple>;
using AQuantGroupSize = QuantGroupSize;
using BQuantGroupSize = SafeTupleElement_t<Tuple, 11, QuantGroupSize>;
using BQLayout = SafeTupleElement_t<Tuple, 12, AQLayout>;
using AccDataType = float; // accumulate always in float
// Get the quant-type specific data types from traits
@@ -86,9 +107,6 @@ class TestCkTileGemmQuantBase : public ::testing::Test
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
// Re-use the AQLayout for BQLayout
using BQLayout = AQLayout;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
@@ -155,7 +173,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
template <ck_tile::QuantType QT>
struct QuantTypeTraits
{
static_assert(QT == ck_tile::QuantType::AQuantGrouped ||
static_assert(QT == ck_tile::QuantType::ABQuantGrouped ||
QT == ck_tile::QuantType::AQuantGrouped ||
QT == ck_tile::QuantType::BQuantGrouped ||
QT == ck_tile::QuantType::RowColQuant ||
QT == ck_tile::QuantType::TensorQuant,
@@ -182,6 +201,16 @@ struct QuantTypeTraits<ck_tile::QuantType::BQuantGrouped>
static constexpr const char* name = "bquant";
};
// Specialization for ABQuantGrouped: grouped scales on BOTH the A and B
// operands.
template <>
struct QuantTypeTraits<ck_tile::QuantType::ABQuantGrouped>
{
    // For ABQuant the compute type follows BDataType, matching the AQuant
    // specialization's convention of computing in the B operand's type.
    template <typename ADataType, typename BDataType>
    using ComputeDataType = BDataType;
    // Short tag used when naming/reporting this quantization mode.
    static constexpr const char* name = "abquant";
};
// Specialization for RowColQuant
template <>
struct QuantTypeTraits<ck_tile::QuantType::RowColQuant>

View File

@@ -664,6 +664,314 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
}
};
// ABQuant-specific test fixture: drives the ck_tile quant GEMM kernel with
// grouped scales on BOTH operands (A and B) and validates the device result
// against the host reference implementation (reference_gemm_abquant).
// CRTP: the base class calls back into run_quant_gemm_impl below.
template <typename Tuple>
class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmABQuant<Tuple>>
{
    using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmABQuant<Tuple>>;
    friend Base;

public:
    using typename Base::AccDataType;
    using typename Base::ADataType;
    using typename Base::ALayout;
    using typename Base::AQLayout;
    using typename Base::AQuantGroupSize;
    using typename Base::BDataType;
    using typename Base::BLayout;
    using typename Base::BQuantGroupSize;
    using typename Base::CDataType;
    using typename Base::CLayout;
    using typename Base::ComputeDataType;
    using typename Base::GemmConfig;
    using typename Base::QDataType;

    // NOTE(review): this hardcodes the B-scale layout to ColumnMajor and
    // shadows the tuple-provided Base::BQLayout (tuple element 12) — confirm
    // the override is intended for every ABQuant test case.
    using BQLayout = ck_tile::tensor_layout::gemm::ColumnMajor;

    static constexpr auto QuantType        = Base::QuantType;
    static constexpr auto PreshuffleB      = Base::PreshuffleB;
    static constexpr auto TiledMMAPermuteN = Base::TiledMMAPermuteN;

protected:
    // No ABQuant-specific per-test setup/teardown is required.
    void SetUpQuantTypeSpecific() {}
    void TearDownQuantTypeSpecific() {}

    // Runs one M x N x K problem on the device and checks it against the host
    // reference. Handles data generation, optional preshuffling of B and of
    // the scale tensors, kernel invocation, and tolerance-based comparison.
    void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
    {
        // Default (dense) strides for each operand in its requested layout.
        const ck_tile::index_t stride_A =
            ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{}));
        const ck_tile::index_t stride_B =
            ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{}));
        const ck_tile::index_t stride_C =
            ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{}));

        // AQuant uses grouped quantization for A matrix: one scale per
        // AQuantGroupSize::kK elements along K (ceil to cover a partial group).
        const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, AQuantGroupSize::kK);
        // BQuant uses block/grouped quantization for B matrix: scales grouped
        // along both N and K.
        const ck_tile::index_t BQN = ck_tile::integer_divide_ceil(N, BQuantGroupSize::kN);
        const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, BQuantGroupSize::kK);

        const ck_tile::index_t stride_AQ =
            ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{}));
        const ck_tile::index_t stride_BQ =
            ck_tile::get_default_stride(BQK, BQN, 0, this->is_row_major(BQLayout{}));

        // Generate test data
        ck_tile::HostTensor<ADataType> a_m_k(
            ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
        // AQLayout is independently specified for each test case
        ck_tile::HostTensor<QDataType> aq_m_aqk( // AQDataType
            ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{})));
        ck_tile::HostTensor<BDataType> b_k_n(
            ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
        ck_tile::HostTensor<QDataType> bq_bqk_bqn(
            ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{})));

        // Initialize data with random values (narrower range for pk_int4 A).
        if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
        {
            ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f}(a_m_k);
        }
        else
        {
            ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f}(a_m_k);
        }
        ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
        ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(aq_m_aqk);
        ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(bq_bqk_bqn);

        // Allocate device memory
        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
        ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() *
                                            sizeof(QDataType)); // AQDataType
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
        ck_tile::DeviceMem bq_bqk_bqn_dev_buf(bq_bqk_bqn.get_element_space_size() *
                                              sizeof(QDataType));
        ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));

        // Copy to device
        if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
        {
            // Permute vector pk_i4x4 data for device implementation
            ck_tile::HostTensor<ADataType> temp = a_m_k;
            ck_tile::permute_vectors_i4x4_b(temp);
            a_m_k_dev_buf.ToDevice(temp.data());
        }
        else
        {
            a_m_k_dev_buf.ToDevice(a_m_k.data());
        }

        // B may need a layout shuffle before upload, depending on config.
        ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
        if constexpr(PreshuffleB)
        {
            if constexpr(TiledMMAPermuteN && BQuantGroupSize::kN == 1)
            {
                printf("PreshuffleB with TiledMMAPermuteN\n");
                b_k_n_dev = ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
            }
            else
            {
                printf("PreshuffleB without TiledMMAPermuteN\n");
                b_k_n_dev = ck_tile::shuffle_b<GemmConfig>(b_k_n);
            }
        }
        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
        {
            ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
        }
        b_k_n_dev_buf.ToDevice(b_k_n_dev.data());

        // A scales: shuffle per K-tile when the config preshuffles quant data.
        if constexpr(Base::GemmConfig::PreshuffleQuant)
        {
            ck_tile::HostTensor<QDataType> aq_shuffle_host =
                ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / AQuantGroupSize::kK);
            aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
        }
        else
        {
            aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
        }

        // B scales: permuted-N path takes precedence over plain preshuffle.
        if constexpr(PreshuffleB && TiledMMAPermuteN && BQuantGroupSize::kN == 1)
        {
            printf("Preshuffle BQ with TiledMMAPermuteN \n");
            ck_tile::HostTensor<QDataType> bq_shuffle_host =
                ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn, BQuantGroupSize::kN);
            bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
        }
        else if constexpr(GemmConfig::PreshuffleQuant)
        {
            ck_tile::HostTensor<QDataType> bq_shuffle_host =
                ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / BQuantGroupSize::kK);
            bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
        }
        else
        {
            bq_bqk_bqn_dev_buf.ToDevice(bq_bqk_bqn.data());
        }

        // Create args for kernel execution
        ck_tile::QuantGemmHostArgs args{
            a_m_k_dev_buf.GetDeviceBuffer(),       // a_ptr
            b_k_n_dev_buf.GetDeviceBuffer(),       // b_ptr
            c_m_n_dev_buf.GetDeviceBuffer(),       // c_ptr
            aq_m_aqk_dev_buf.GetDeviceBuffer(),    // aq_ptr (scales)
            bq_bqk_bqn_dev_buf.GetDeviceBuffer(),  // bq_ptr (scales)
            1,                                     // k_batch
            M,
            N,
            K,   // M, N, K
            AQK, // QK_A
            BQK, // QK_B
            stride_A,
            stride_B,
            stride_C,
            stride_AQ,
            stride_BQ // strides
        };

        // Run the kernel
        ck_tile::stream_config stream_config{};
        this->invoke_quant_gemm(args, stream_config);

        // Validation using reference implementation
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
        c_m_n_host_ref.SetZero();

        // Run reference ABQuant implementation
        ck_tile::reference_gemm_abquant<ADataType,
                                        QDataType, // AQDataType
                                        BDataType,
                                        QDataType,
                                        AccDataType,
                                        CDataType,
                                        AQuantGroupSize,
                                        BQuantGroupSize>(
            a_m_k, aq_m_aqk, b_k_n, bq_bqk_bqn, c_m_n_host_ref);

        // Get device result
        ck_tile::HostTensor<CDataType> c_m_n_dev_result(
            ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());

        // Calculate error tolerances.
        // NOTE(review): max_element picks the largest signed value, not the
        // largest magnitude — if all reference outputs were negative this
        // would underestimate the accumulation scale; presumably mirrors the
        // other quant fixtures — confirm.
        const float max_accumulated_value =
            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
        const auto rtol_atol =
            this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
                K, 1, max_accumulated_value);

        // Validate results
        bool pass = ck_tile::check_err(c_m_n_dev_result,
                                       c_m_n_host_ref,
                                       "Error: Incorrect results!",
                                       rtol_atol.at(ck_tile::number<0>{}),
                                       rtol_atol.at(ck_tile::number<1>{}));
        EXPECT_TRUE(pass) << "ABQuantGrouped validation failed with M=" << M << ", N=" << N
                          << ", K=" << K;
        if(!pass)
        {
            std::cout << "ABQuantGrouped - Relative error threshold: "
                      << rtol_atol.at(ck_tile::number<0>{})
                      << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                      << std::endl;
        }
    }

private:
    // ABQuant-specific pipeline implementation: instantiates the pipeline,
    // epilogue, and kernel for the compile-time hot-loop/tail-number variant
    // chosen at runtime, then launches it. Called by the CRTP base.
    template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
    void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
                             const ck_tile::stream_config& s)
    {
        using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
                                                                    BDataType,
                                                                    AccDataType,
                                                                    CodegenGemmShape,
                                                                    CodegenGemmTraits,
                                                                    ComputeDataType>;
        // NOTE(review): both branches of this conditional are identical —
        // presumably a placeholder for a future PreshuffleB pipeline; confirm
        // or collapse to a plain alias.
        using BaseGemmPipeline =
            std::conditional_t<PreshuffleB == false,
                               ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
                               ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;

        // Round K up to a whole number of K-tiles to derive the loop count.
        const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        // Invoked by TailHandler with compile-time hot-loop/tail-number tags.
        const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v = tail_number_.value;
            constexpr bool transpose_c = CodegenGemmTraits::TransposeC;

            using PipelineProblem =
                ck_tile::GemmABQuantPipelineProblem<ADataType,
                                                    QDataType, // AQDataType
                                                    BDataType,
                                                    QDataType,
                                                    AccDataType,
                                                    CodegenGemmShape,
                                                    CodegenGemmTraits,
                                                    AQuantGroupSize,
                                                    BQuantGroupSize,
                                                    transpose_c,
                                                    ComputeDataType,
                                                    ck_tile::GemmPipelineScheduler::Intrawave,
                                                    has_hot_loop_v,
                                                    tail_number_v>;
            // NOTE(review): identical branches again — see BaseGemmPipeline.
            using GemmPipeline =
                std::conditional_t<PreshuffleB == false,
                                   ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
                                   ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 ck_tile::tuple<>,
                                                 AccDataType,
                                                 CDataType,
                                                 ck_tile::tuple<>,
                                                 CLayout,
                                                 ck_tile::element_wise::PassThrough,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 Base::M_Warp,
                                                 Base::N_Warp,
                                                 Base::M_Warp_Tile,
                                                 Base::N_Warp_Tile,
                                                 Base::K_Warp_Tile,
                                                 transpose_c,
                                                 ck_tile::memory_operation_enum::set,
                                                 1,
                                                 false,
                                                 1,
                                                 TiledMMAPermuteN>>;
            using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
                                                    GemmPipeline,
                                                    GemmEpilogue,
                                                    ck_tile::QuantType::ABQuantGrouped>;
            auto kargs = Kernel::MakeKernelArgs(args);

            const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
            const dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Arguments not supported for ABQuant kernel");
            }
            // NOTE(review): uses GemmConfigBase::kBlockPerCu rather than the
            // per-test GemmConfig::kBlockPerCu — confirm this is intentional.
            ck_tile::launch_kernel(s,
                                   ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
                                       Kernel{}, grids, blocks, 0, kargs));
        };

        return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
    }
};
template <typename Tuple>
class TestCkTileGemmPreshuffleBBQuant : public TestCkTileGemmBQuant<Tuple>
{