Weight Preshuffle Block Scale gemm support (#2877)

* initial commit

* remove extra files

* fixing errors

* updated README file with the mapping of different quant types to different configs

* addressing review comments

* addressing review comments

* Resolved merge conflicts

* [CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled

The get_preshuffle_or was not working as expected, which led to incorrect behavior
in the quantization preshuffle process. This change replaces it with the more reliable
is_quantpreshuffle_enabled function to properly determine when preshuffle should be applied.

* initial commit

* debugging

* working fp8 for init constant

* fp8 working with all inits

* updated block level code with comments

* changing the loop iter

* debugging

* debugging

* debugging

* code fix

* code clean up

* clang formatted

* Add comment

* code cleanup

* clang formatted

* merge conflicts fixes

* applying the latest int4 changes to the pipeline

* fixing test code for updated traits

* Adding gtest

* review comments addressed

* addressing review comments

* remove c++20 code

* added flush cache changes

---------

Co-authored-by: Cong Ma <congma13@amd.com>
Co-authored-by: root <root@banff-cyxtera-s73-2.ctr.dcgpu>

[ROCm/composable_kernel commit: 81458a6681]
This commit is contained in:
Khushbu Agarwal
2025-09-29 12:46:37 -07:00
committed by GitHub
parent 47b8632296
commit 7c20b1f690
17 changed files with 1129 additions and 53 deletions

View File

@@ -53,6 +53,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
public:
void SetUp() override { static_cast<Derived*>(this)->SetUpQuantTypeSpecific(); }
@@ -62,10 +65,9 @@ class TestCkTileGemmQuantBase : public ::testing::Test
// Common test execution logic
void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kPreshuffle = false;
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -77,11 +79,15 @@ class TestCkTileGemmQuantBase : public ::testing::Test
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
kPreshuffle,
PreshuffleQuant,
PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantType>;
QuantType,
ALayout,
BLayout,
DoubleSmemBuffer>;
// Let the derived class create the appropriate pipeline and epilogue
static_cast<Derived*>(this)
@@ -125,6 +131,19 @@ class TestCkTileGemmQuantBase : public ::testing::Test
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
// Host-side helper that re-lays-out a 2-D tensor by viewing its elements as a
// 5-D tensor and permuting the dimensions of that view.
//
// @param t  Rank-2 host tensor; length 0 is read as K and length 1 as N.
// @return   The tensor returned by ck_tile::reference_permute on the 5-D view.
//
// NOTE(review): n_ and k_ are presumably expected to be multiples of
// N_Warp_Tile and K_Warp_Tile. The integer divisions below truncate otherwise,
// and the view would then hold fewer elements than `t`; this is not checked
// here -- confirm against callers.
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
// Only rank-2 input is supported; the assert is compiled out under NDEBUG.
assert(t.get_lengths().size() == 2);
// Length 1 is treated as N, length 0 as K.
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
// Each K_Warp_Tile chunk is split into `divisor` groups:
// 2 when N_Warp_Tile is 32, otherwise 4.
constexpr int divisor = N_Warp_Tile == 32 ? 2 : 4;
// 5-D view with lengths
// (N / N_Warp_Tile, N_Warp_Tile, K / K_Warp_Tile, divisor, K_Warp_Tile / divisor).
// NOTE(review): putting N first assumes N is the slowest-varying dimension
// of the source storage -- TODO confirm against the B layout used by callers.
ck_tile::HostTensor<T> t_view(
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
// Fill the view with the source elements in iteration order (no reordering yet).
std::copy(t.begin(), t.end(), t_view.begin());
// Permute the view with dimension order {0, 2, 3, 1, 4}; presumably this moves
// the N_Warp_Tile dimension between `divisor` and the innermost K sub-chunk --
// confirm against reference_permute's convention.
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
};
// Define generic QuantTypeTraits template (will be specialized)

View File

@@ -24,6 +24,7 @@ struct GemmConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
// Default GEMM tile sizes for tests
@@ -40,6 +41,41 @@ struct GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = 32;
};
// Compile-time GEMM configuration for the preshuffled-B quant GEMM tests.
// Unlike GemmConfigBase, it sets PreshuffleB = true and DoubleSmemBuffer = true
// and carries its own tile sizes.
struct GemmConfigPreshuffleB
{
// No padding along M, N or K.
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
// NOTE(review): "Paritioner" is spelled this way in the identifiers; left
// unchanged because other code may refer to these names.
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
// Quant-tensor preshuffle stays off; only the B (weight) tensor is preshuffled.
static constexpr bool PreshuffleQuant = false;
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
// Default GEMM tile sizes for tests
// Block tile: 16 x 64 x 256 (M x N x K).
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
// Warp counts per dimension: 1 x 4 x 1 (M x N x K).
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
// Warp tile: 16 x 16 x 64 (M x N x K).
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
};
template <typename Tuple>
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
{
@@ -288,6 +324,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
static constexpr auto PreshuffleB = Base::PreshuffleB;
protected:
void SetUpQuantTypeSpecific() {}
@@ -324,16 +361,23 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> temp = b_k_n;
ck_tile::permute_vectors_i4x4_b(temp);
b_k_n_dev_buf.ToDevice(temp.data());
if constexpr(PreshuffleB)
{
b_k_n_dev = this->shuffle_b(b_k_n);
}
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
if constexpr(PreshuffleB)
{
b_k_n_dev = this->shuffle_b(b_k_n);
}
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
@@ -419,7 +463,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using BaseGemmPipeline = std::conditional_t<
PreshuffleB == false,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -443,7 +490,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmPipeline =
std::conditional_t<PreshuffleB == false,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -486,6 +537,11 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
}
};
// Fixture for the preshuffled-B variant of the BQuant tests. It adds nothing to
// TestCkTileGemmBQuant; it is a distinct type that can be registered as its own
// typed test suite with its own type list.
template <typename Tuple>
class TestCkTileGemmPreshuffleBBQuant : public TestCkTileGemmBQuant<Tuple>
{
};
// RowColQuant-specific test fixture
template <typename Tuple>
class TestCkTileGemmRowColQuant

View File

@@ -41,6 +41,14 @@ using BQuantTypes = ::testing::Types<
>;
// clang-format on
// Type list for the preshuffled-B BQuant tests; every entry uses GemmConfigPreshuffleB.
// NOTE(review): tuple fields are presumably (ALayout, BLayout, CLayout, A type, B type,
// quant-scale type, C type, quant mode, GEMM config, group size) -- confirm against the
// fixture's tuple unpacking.
// clang-format off
using BPreshuffleBQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleB, GroupSize>
>;
// clang-format on
// clang-format off
using RowColQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, RowColQuant, GemmConfigBase, GroupSize>,
@@ -58,6 +66,7 @@ using TensorQuantTypes = ::testing::Types<
// Test suites for each quantization type
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);

View File

@@ -15,6 +15,11 @@ TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
this->run_test_with_validation(1024, 1024, 1024);
}
// BQuant tests with preshuffled B (TestCkTileGemmPreshuffleBBQuant fixture)
// NOTE(review): the three arguments are presumably M, N, K -- confirm against
// run_test_with_validation's signature.
TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// RowColQuant tests
TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
{