Make CK TILE GEMM Aquant support block tile 128x128x128 (#3325)

* [CK TILE GEMM Quant] Rename GemmConfigBQuantPrefill to GemmConfigQuantPrefill in examples

* [CK TILE GEMM Quant] update tile distribution of aquant

* [CK TILE GEMM Quant] update aquant register offset calculation

* [CK TILE GEMM Quant] Reimplement aquant register offset calculation

* [CK TILE GEMM Quant] Add more unit tests of Aquant

- Test M128xN128xK128

* [CK TILE GEMM Quant] Add more comments to Gemm Aquant
This commit is contained in:
Cong Ma
2025-12-01 16:04:37 -07:00
committed by GitHub
parent 7873f8fa13
commit 23fb253c4e
11 changed files with 58 additions and 46 deletions

View File

@@ -74,7 +74,7 @@ User need to select correct mapping of config for each quant mode:
|:--------|:-----:|:-----:|-------|
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill |
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill |
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp | GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill |

View File

@@ -6,6 +6,10 @@
template <typename T>
using GemmConfig = GemmConfigQuantDecode<T>;
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
// template <typename T>
// using GemmConfig = GemmConfigQuantPrefill<T>;
void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
};
template <typename PrecType>
struct GemmConfigBQuantPrefill : public GemmConfigBase
struct GemmConfigQuantPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill<PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
};
template <typename PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;

View File

@@ -373,8 +373,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
{
// Need to multiply aquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// The accumulated C tile has the standard distribution. For example, in a
// 32x32 C tile, lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
@@ -388,35 +388,31 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
// MIters per warp
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
// Reg block offset based on mIter
constexpr index_t reg_block_offset =
((mIter / mIters_per_warp) * Traits::AQPerBlock);
constexpr index_t lane_base_offset =
(mIter % mIters_per_warp) * WarpGemm::kM;
// Scale tensor offset along K
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
// Directly index into thread buffer corresponding to
// desired row coefficient
// Each thread stores AQPerBlock scale values per M iteration.
constexpr index_t reg_block_offset = mIter * Traits::AQPerBlock;
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
;
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
// Multiply by 4 because output is stored in tiles of 4
// x CNLane
constexpr uint32_t row_base =
((reg_offset_for_row_data / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
((reg_offset_for_row_data % kTiledCMsPerWarp) / WarpGemm::kCMLane);
// Divide M dimension of C Warp tile into groups of
// (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane)
// m_base_offset_of_c_row indicates which group the current c_row belongs
// to.
constexpr index_t m_base_offset_of_c_row =
(c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) *
(WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
// M offset of each thread within its group (see comment above)
index_t m_base_offset_of_lane =
(get_lane_id() / WarpGemm::kN *
WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
// M offset relative to c_row within the subgroup of size kCM1PerLane
constexpr index_t m_offset_of_c_row =
c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1);
// Lane index to source scale from
uint32_t src_lane_idx =
lane_base_offset + row_base + (__lane_id() / WarpGemm::kN * kTileRows);
m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row;
return exchange_quant_value_across_lanes(scale_reg, src_lane_idx);
}

View File

@@ -94,21 +94,20 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
// # of elements per thread
constexpr index_t X = XPerTile;
constexpr index_t Y0 = 1;
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t Y2 = MWarps;
constexpr index_t Y3 = WarpGemm::kM;
static_assert(Y3 >= WarpGemm::kM,
constexpr index_t YR = 1;
constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t Y1 = MWarps;
constexpr index_t Y2 = WarpGemm::kM;
static_assert(Y2 >= WarpGemm::kM,
"Scales for all rows must be available within the warp.");
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
tile_distribution_encoding<sequence<NWarps, YR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<1, 0>, sequence<0, 1>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2>,
sequence<1, 0>>{});
sequence<0, 0>>{});
}
}
};

View File

@@ -53,6 +53,13 @@ struct GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
};
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
};
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;

View File

@@ -39,6 +39,12 @@ using AQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// PreshuffleQuant = false && TransposeC = false && Prefill
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
// PreshuffleQuant = false && TransposeC = true
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,