adding test

This commit is contained in:
khuagarw
2025-11-12 01:30:38 +00:00
parent 075c36b5f9
commit 0f79fa5aed
2 changed files with 17 additions and 1 deletions

View File

@@ -103,6 +103,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
{
static constexpr bool PreshuffleQuant = true;
};
template <typename Tuple>
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
{
@@ -429,6 +434,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_bqn);
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
}
else if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<QDataType> bq_shuffle_host =
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK);
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
}
else
{
bq_bqk_bqn_dev_buf.ToDevice(bq_bqk_bqn.data());

View File

@@ -111,7 +111,12 @@ using BPreshuffleBQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>
>;
// clang-format on