mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Make CK TILE GEMM Aquant support block tile 128x128x128 (#3325)
* [CK TILE GEMM Quant] Rename GemmConfigBQuantPrefill to GemmConfigQuantPrefill in examples * [CK TILE GEMM Quant] update tile distribution of aquant * [CK TILE GEMM Quant] update aquant register offset calculation * [CK TILE GEMM Quant] Reimplement aquant register offset calculation * [CK TILE GEMM Quant] Add more unit tests of Aquant - Test M128xN128xK128 * [CK TILE GEMM Quant] Add more comments to Gemm Aquant
This commit is contained in:
@@ -53,6 +53,13 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
|
||||
};
|
||||
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
|
||||
@@ -39,6 +39,12 @@ using AQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = false && Prefill
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = true
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
|
||||
Reference in New Issue
Block a user