[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm (#3629)

* initial commit

* preshuffleQuant support for ABQuant

* fix mxfp4 to use correct QuantGroupSize

* addressing review comments and separated PreshuffleQuant for A and B

* updated grouped gemm example for updated traits definition

* fix for CI failure

* updated grouped_gemm_abquant test for updated traits definition

* updated grouped_gemm_abquant test for updated traits definition
This commit is contained in:
Khushbu Agarwal
2026-01-28 19:45:09 -08:00
committed by GitHub
parent e3556fed04
commit 9b168082b7
33 changed files with 490 additions and 367 deletions

View File

@@ -45,7 +45,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# ABQuant tests
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
test_gemm_quant_aquant_prefill.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c
test_gemm_quant_aquant_transpose_c.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle
test_gemm_quant_aquant_preshuffle.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# ABQuant tests split into 4 files
add_gtest_executable(test_tile_gemm_quant_abquant_base
test_gemm_quant_abquant_base.cpp
)
@@ -61,21 +76,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# AQuant tests
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
test_gemm_quant_aquant_prefill.cpp
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant
test_gemm_quant_abquant_preshuffleQuant.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c
test_gemm_quant_aquant_transpose_c.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle
test_gemm_quant_aquant_preshuffle.cpp
)
target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
target_compile_options(test_tile_gemm_quant_abquant_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
# BQuant tests (without PreshuffleB) - split into 6 files
add_gtest_executable(test_tile_gemm_quant_bquant_1d_128
@@ -188,6 +192,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
test_tile_gemm_quant_aquant_prefill
test_tile_gemm_quant_aquant_transpose_c
test_tile_gemm_quant_aquant_preshuffle
# ABQuant tests
test_tile_gemm_quant_abquant_base
test_tile_gemm_quant_abquant_padding
test_tile_gemm_quant_abquant_preshuffle
test_tile_gemm_quant_abquant_preshuffleQuant
# BQuant tests
test_tile_gemm_quant_bquant_1d_128
test_tile_gemm_quant_bquant_1d_64

View File

@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
// NOTE(review): BF8 and PkInt4 are not referenced by the type list in this file; they look
// like leftovers shared with the sibling ABQuant test files — confirm before removing.
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
// Compile-time tag selecting the ABQuantGrouped quantization mode for the test tuples.
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// 1D quantization group shape.
// NOTE(review): the sequence is presumably ordered (M, N, K), i.e. groups of 128 along K —
// confirm against ck_tile::QuantGroupShape.
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2D block size for BQuant (presumably 128 along N and 128 along K — confirm); used as the
// BQuantGroupSize of the second test tuple.
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for the ABQuant preshuffle-quant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// Both entries use FP8 A/B inputs, float quant data, Half output and
// GemmConfigPreshuffleQuantPrefill; they differ only in BQuantGroupSize
// (GroupSize vs. GroupSize2D128N).
// clang-format off
using ABQuantPreshuffleQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant: instantiates TestCkTileGemmABQuant once per tuple in
// ABQuantPreshuffleQuantTypes.
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleQuantTypes);
// ABQuant tests
// Runs the fixture's validated GEMM with sizes (1024, 1024, 1024).
// NOTE(review): the three arguments are presumably M, N, K — confirm against
// run_test_with_validation in test_gemm_quant_fixtures.hpp.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -75,7 +75,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
static constexpr bool APreshuffleQuant = GemmConfig::APreshuffleQuant;
static constexpr bool BPreshuffleQuant = GemmConfig::BPreshuffleQuant;
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
@@ -111,7 +112,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
PreshuffleQuant,
APreshuffleQuant,
BPreshuffleQuant,
PreshuffleB,
ALayout,
BLayout,

View File

@@ -34,7 +34,8 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = false;
static constexpr bool APreshuffleQuant = false;
static constexpr bool BPreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false;
@@ -110,7 +111,7 @@ struct GemmConfigMxFp4 : public GemmConfigBase
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool APreshuffleQuant = true;
};
struct GemmConfigTransposeC : public GemmConfigBase
@@ -120,8 +121,8 @@ struct GemmConfigTransposeC : public GemmConfigBase
struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool TransposeC = true;
static constexpr bool APreshuffleQuant = true;
static constexpr bool TransposeC = true;
};
struct GemmConfigPadding : public GemmConfigBase
@@ -138,7 +139,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigDecode
struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
@@ -149,7 +150,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill
@@ -160,7 +161,7 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
{
static constexpr bool PreshuffleQuant = true;
static constexpr bool BPreshuffleQuant = true;
};
template <typename Tuple>
@@ -244,7 +245,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
// aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
if constexpr(Base::GemmConfig::PreshuffleQuant)
if constexpr(Base::GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<QDataType> aq_shuffle_host =
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
@@ -481,7 +482,7 @@ class TestCkTileGemmAQuantMem
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
// aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
if constexpr(Base::GemmConfig::PreshuffleQuant)
if constexpr(Base::GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<QDataType> aq_shuffle_host =
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
@@ -727,7 +728,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn, QuantGroupSize::kN);
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
}
else if constexpr(GemmConfig::PreshuffleQuant)
else if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<QDataType> bq_shuffle_host =
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK);
@@ -1024,7 +1025,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
if constexpr(Base::GemmConfig::PreshuffleQuant)
if constexpr(Base::GemmConfig::APreshuffleQuant)
{
ck_tile::HostTensor<QDataType> aq_shuffle_host =
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / AQuantGroupSize::kK);
@@ -1041,7 +1042,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn, BQuantGroupSize::kN);
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
}
else if constexpr(GemmConfig::PreshuffleQuant)
else if constexpr(GemmConfig::BPreshuffleQuant)
{
ck_tile::HostTensor<QDataType> bq_shuffle_host =
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / BQuantGroupSize::kK);

View File

@@ -117,6 +117,7 @@ class TestCkTileGroupedGemmABQuant : public ::testing::Test
Config::kPadN,
Config::kPadK,
false,
false,
Config::PreshuffleB,
ALayout,
BLayout,
@@ -241,6 +242,7 @@ class TestCkTileGroupedGemmABQuant : public ::testing::Test
Config::kPadN,
Config::kPadK,
false,
false,
Config::PreshuffleB,
ALayout,
BLayout,

View File

@@ -112,6 +112,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
false,
PreshuffleB,
ALayout,
BLayout,
@@ -289,6 +290,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
GroupedGemKernelParam::kPadN,
GroupedGemKernelParam::kPadK,
false,
false,
PreshuffleB,
ALayout,
BLayout,