mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_Tile] Adding support for preshuffleQuant in AB quant Block Scale Gemm (#3629)
* initial commit * preshuffleQuant support for ABQuant * fix mxfp4 to use correct QuantGroupSize * addressing review comments and seperated Preshufflequant for A and B * updated grouped gemm example for updated traits definition * fix for CI failure * updated grouped_gemm_abquant test for updated traits definition * updated grouped_gemm_abquant test for updated traits definition
This commit is contained in:
@@ -45,7 +45,22 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
|
||||
target_compile_options(test_tile_gemm_quant_aquant_base_ccr PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# ABQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
|
||||
test_gemm_quant_aquant_prefill.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c
|
||||
test_gemm_quant_aquant_transpose_c.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle
|
||||
test_gemm_quant_aquant_preshuffle.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# ABQuant tests split into 4 files
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_base
|
||||
test_gemm_quant_abquant_base.cpp
|
||||
)
|
||||
@@ -61,21 +76,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# AQuant tests
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_prefill
|
||||
test_gemm_quant_aquant_prefill.cpp
|
||||
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffleQuant
|
||||
test_gemm_quant_abquant_preshuffleQuant.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_transpose_c
|
||||
test_gemm_quant_aquant_transpose_c.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_transpose_c PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_tile_gemm_quant_aquant_preshuffle
|
||||
test_gemm_quant_aquant_preshuffle.cpp
|
||||
)
|
||||
target_compile_options(test_tile_gemm_quant_aquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_tile_gemm_quant_abquant_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
# BQuant tests (without PreshuffleB) - split into 6 files
|
||||
add_gtest_executable(test_tile_gemm_quant_bquant_1d_128
|
||||
@@ -188,6 +192,11 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
test_tile_gemm_quant_aquant_prefill
|
||||
test_tile_gemm_quant_aquant_transpose_c
|
||||
test_tile_gemm_quant_aquant_preshuffle
|
||||
# ABQuant tests
|
||||
test_tile_gemm_quant_abquant_base
|
||||
test_tile_gemm_quant_abquant_padding
|
||||
test_tile_gemm_quant_abquant_preshuffle
|
||||
test_tile_gemm_quant_abquant_preshuffleQuant
|
||||
# BQuant tests
|
||||
test_tile_gemm_quant_bquant_1d_128
|
||||
test_tile_gemm_quant_bquant_1d_64
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using ABQuantGrouped =
|
||||
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
|
||||
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
|
||||
|
||||
// 2d block sizes for BQuant
|
||||
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
|
||||
|
||||
// Type combinations for ABQuant tests
|
||||
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
|
||||
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
|
||||
// clang-format off
|
||||
using ABQuantPreshuffleQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize, GroupSize, ColumnMajor>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleQuantPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suite for ABQuant
|
||||
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleQuantTypes);
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -75,7 +75,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
|
||||
static constexpr bool APreshuffleQuant = GemmConfig::APreshuffleQuant;
|
||||
static constexpr bool BPreshuffleQuant = GemmConfig::BPreshuffleQuant;
|
||||
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
|
||||
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
|
||||
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
@@ -111,7 +112,8 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
PreshuffleQuant,
|
||||
APreshuffleQuant,
|
||||
BPreshuffleQuant,
|
||||
PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -34,7 +34,8 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool APreshuffleQuant = false;
|
||||
static constexpr bool BPreshuffleQuant = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
@@ -110,7 +111,7 @@ struct GemmConfigMxFp4 : public GemmConfigBase
|
||||
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool APreshuffleQuant = true;
|
||||
};
|
||||
|
||||
struct GemmConfigTransposeC : public GemmConfigBase
|
||||
@@ -120,8 +121,8 @@ struct GemmConfigTransposeC : public GemmConfigBase
|
||||
|
||||
struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool TransposeC = true;
|
||||
static constexpr bool APreshuffleQuant = true;
|
||||
static constexpr bool TransposeC = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPadding : public GemmConfigBase
|
||||
@@ -138,7 +139,7 @@ struct GemmConfigPreshuffleBDecode : public GemmConfigDecode
|
||||
|
||||
struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
|
||||
@@ -149,7 +150,7 @@ struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill
|
||||
|
||||
struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill
|
||||
@@ -160,7 +161,7 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
|
||||
|
||||
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
static constexpr bool BPreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
@@ -244,7 +245,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
// aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
if constexpr(Base::GemmConfig::PreshuffleQuant)
|
||||
if constexpr(Base::GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
@@ -481,7 +482,7 @@ class TestCkTileGemmAQuantMem
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
// aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
if constexpr(Base::GemmConfig::PreshuffleQuant)
|
||||
if constexpr(Base::GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
@@ -727,7 +728,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn, QuantGroupSize::kN);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
else if constexpr(GemmConfig::BPreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
@@ -1024,7 +1025,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
|
||||
if constexpr(Base::GemmConfig::PreshuffleQuant)
|
||||
if constexpr(Base::GemmConfig::APreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> aq_shuffle_host =
|
||||
ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / AQuantGroupSize::kK);
|
||||
@@ -1041,7 +1042,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn, BQuantGroupSize::kN);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
else if constexpr(GemmConfig::BPreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / BQuantGroupSize::kK);
|
||||
|
||||
@@ -117,6 +117,7 @@ class TestCkTileGroupedGemmABQuant : public ::testing::Test
|
||||
Config::kPadN,
|
||||
Config::kPadK,
|
||||
false,
|
||||
false,
|
||||
Config::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -241,6 +242,7 @@ class TestCkTileGroupedGemmABQuant : public ::testing::Test
|
||||
Config::kPadN,
|
||||
Config::kPadK,
|
||||
false,
|
||||
false,
|
||||
Config::PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -112,6 +112,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
false,
|
||||
false,
|
||||
PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
@@ -289,6 +290,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
false,
|
||||
false,
|
||||
PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
Reference in New Issue
Block a user