mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
Merge commit '8111572785d3de98457940f2b5ca6fe9cf7603af' into develop
This commit is contained in:
@@ -113,6 +113,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
|
||||
{
|
||||
@@ -436,7 +441,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_bqn);
|
||||
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -111,7 +111,12 @@ using BPreshuffleBQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user