[CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle (#2897)

* [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle

When TransposeC and QuantPreshuffle are both true, Aquant generates
correct result.

* [CK TILE GEMM] Support Aquant GEMM with transposeC and preshuffle

- Add unit tests

* Fix bug in is_quantpreshuffle_enabled

* clang format

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
Cong Ma
2025-10-02 12:13:51 -06:00
committed by GitHub
parent a4ab33f539
commit 6fc28ab493
7 changed files with 109 additions and 15 deletions

View File

@@ -7,7 +7,9 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
# Typed Test Suite for GEMM Quantization
add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp)
add_gtest_executable(test_tile_gemm_quant_typed
test_gemm_quant_typed.cpp
)
target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")

View File

@@ -87,6 +87,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test
QuantType,
ALayout,
BLayout,
GemmConfig::TransposeC,
DoubleSmemBuffer>;
// Let the derived class create the appropriate pipeline and epilogue

View File

@@ -41,6 +41,22 @@ struct GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = 32;
};
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;
};
struct GemmConfigTransposeC : public GemmConfigBase
{
static constexpr bool TransposeC = true;
};
struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool TransposeC = true;
};
struct GemmConfigPreshuffleB
{
static constexpr bool kPadM = false;
@@ -100,6 +116,24 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
template <typename T>
auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
{
if(t->get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int m_ = t->get_lengths()[0];
int aqk_ = t->get_lengths()[1];
if(aqk_ % block_aq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a aqk of multiple times of block_aq_k.");
}
ck_tile::HostTensor<T> t_view({m_, aqk_ / block_aq_k, block_aq_k});
std::copy(t->begin(), t->end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
// AQuant-specific data generation
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
@@ -150,7 +184,17 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
// aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
if constexpr(Base::GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<QDataType> aq_shuffle_host =
shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize);
aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data());
}
else
{
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
// Create args for kernel execution
@@ -245,7 +289,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
using PipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
@@ -701,7 +745,7 @@ class TestCkTileGemmRowColQuant
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
ADataType,
@@ -916,7 +960,7 @@ class TestCkTileGemmTensorQuant
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
constexpr bool transpose_c = CodegenGemmTraits::TransposeC;
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
ADataType,

View File

@@ -25,10 +25,29 @@ using GroupSize = std::integral_constant<unsigned int, 128>;
// Type combinations for each quantization type
// clang-format off
using AQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// PreshuffleQuant = false && TransposeC = true
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
// PreshuffleQuant = true && TransposeC = false
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuant, GroupSize>,
// PreshuffleQuant = true && TransposeC = true
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPreshuffleQuantTransposeC, GroupSize>
>;
// clang-format on