mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Make CK TILE GEMM Aquant support block tile 128x128x128 (#3325)
* [CK TILE GEMM Quant] Rename GemmConfigBQuantPrefill to GemmConfigQuantPrefill in examples
* [CK TILE GEMM Quant] update tile distribution of aquant
* [CK TILE GEMM Quant] update aquant register offset calculation
* [CK TILE GEMM Quant] Reimplement aquant register offset calculation
* [CK TILE GEMM Quant] Add more unit tests of Aquant
  - Test M128xN128xK128
* [CK TILE GEMM Quant] Add more comments to Gemm Aquant
This commit is contained in:
@@ -74,7 +74,7 @@ Users need to select the correct config mapping for each quant mode:
|
||||
|:--------|:-----:|:-----:|-------|
|
||||
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
|
||||
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
|
||||
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
|
||||
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill |
|
||||
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
|
||||
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill |
|
||||
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp | GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill |
|
||||
|
||||
@@ -6,6 +6,10 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
|
||||
// template <typename T>
|
||||
// using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill<PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
|
||||
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
|
||||
@@ -373,8 +373,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
// Need to multiply aquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// The accumulated C tile has the standard distribution. For example, in a
|
||||
// 32x32 C tile, lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
@@ -388,35 +388,31 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
|
||||
// MIters per warp
|
||||
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
|
||||
|
||||
// Reg block offset based on mIter
|
||||
constexpr index_t reg_block_offset =
|
||||
((mIter / mIters_per_warp) * Traits::AQPerBlock);
|
||||
|
||||
constexpr index_t lane_base_offset =
|
||||
(mIter % mIters_per_warp) * WarpGemm::kM;
|
||||
|
||||
// Scale tensor offset along K
|
||||
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
|
||||
// Directly index into thread buffer corresponding to
|
||||
// desired row coefficient
|
||||
// Each thread stores AQPerBlock scale values per M iteration.
|
||||
constexpr index_t reg_block_offset = mIter * Traits::AQPerBlock;
|
||||
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
|
||||
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
|
||||
;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
// x CNLane
|
||||
constexpr uint32_t row_base =
|
||||
((reg_offset_for_row_data / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
|
||||
((reg_offset_for_row_data % kTiledCMsPerWarp) / WarpGemm::kCMLane);
|
||||
// Divide the M dimension of the C warp tile into groups of
|
||||
// (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) rows.
|
||||
// m_base_offset_of_c_row is the starting M offset of the group that the
|
||||
// current c_row belongs to.
|
||||
constexpr index_t m_base_offset_of_c_row =
|
||||
(c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) *
|
||||
(WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
|
||||
|
||||
// M offset of each thread within its group (see comment above)
|
||||
index_t m_base_offset_of_lane =
|
||||
(get_lane_id() / WarpGemm::kN *
|
||||
WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
|
||||
|
||||
// M offset wrt. c_row in the subgroup of kCM1PerLane
|
||||
constexpr index_t m_offset_of_c_row =
|
||||
c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1);
|
||||
|
||||
// Lane index to source scale from
|
||||
uint32_t src_lane_idx =
|
||||
lane_base_offset + row_base + (__lane_id() / WarpGemm::kN * kTileRows);
|
||||
m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row;
|
||||
|
||||
return exchange_quant_value_across_lanes(scale_reg, src_lane_idx);
|
||||
}
|
||||
|
||||
@@ -94,21 +94,20 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
|
||||
// # of elements per thread
|
||||
constexpr index_t X = XPerTile;
|
||||
|
||||
constexpr index_t Y0 = 1;
|
||||
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
constexpr index_t Y2 = MWarps;
|
||||
constexpr index_t Y3 = WarpGemm::kM;
|
||||
static_assert(Y3 >= WarpGemm::kM,
|
||||
constexpr index_t YR = 1;
|
||||
constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
constexpr index_t Y1 = MWarps;
|
||||
constexpr index_t Y2 = WarpGemm::kM;
|
||||
static_assert(Y2 >= WarpGemm::kM,
|
||||
"Scales for all rows must be available within the warp.");
|
||||
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
|
||||
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 1>>,
|
||||
tuple<sequence<2, 0>, sequence<0, 3>>,
|
||||
tile_distribution_encoding<sequence<NWarps, YR>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
|
||||
tuple<sequence<1, 0>, sequence<0, 1>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -53,6 +53,13 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
|
||||
};
|
||||
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
|
||||
@@ -39,6 +39,12 @@ using AQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = false && Prefill
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = true
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
|
||||
Reference in New Issue
Block a user