mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_Tile] Support for preshuffle weight(B) quant tensor for block scale gemm (#3165)
* formatted * formatted * formatting * formatting * formatting * [CK TILE GEMM] Refactor block_scale_gemm examples - Split cpp file to reduce building time - Support multiple GemmConfig * [CK TILE GEMM] Refactor block_scale_gemm examples - Update Readme * enable prefill shapes * [CK TILE GEMM] Refactor block_scale_gemm examples - Add support for rowcol and tensor GEMM operations * [CK TILE GEMM] Refactor block_scale_gemm examples - Update README * adding preshuffle quant as new parameter and its associated new files * remove debugging statements * adding test * enable preshuffle quant with permuteN * updating readme and correcponding gemmconfigs * updating cmake file * fixing CI failures for grouped quant gemm * addressing review comments * fixing CI issue * addressing reveiw comments * formatting * formatting * fixing aquant operator overlaoding * formatting --------- Co-authored-by: Cong Ma <congma13@amd.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -113,6 +113,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
|
||||
{
|
||||
@@ -436,7 +441,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
{
|
||||
printf("Preshuffle BQ with TiledMMAPermuteN \n");
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq_permuteN<GemmConfig>(bq_bqk_bqn);
|
||||
ck_tile::bq_permuteN<GemmConfig>(bq_bqk_bqn);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<QDataType> bq_shuffle_host =
|
||||
ck_tile::shuffle_bq(&bq_bqk_bqn, GemmConfig::K_Tile / QuantGroupSize::kK);
|
||||
bq_bqk_bqn_dev_buf.ToDevice(bq_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -111,7 +111,12 @@ using BPreshuffleBQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPrefillTiledPermuteN, GroupSize>,
|
||||
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantDecode, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user