mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Preshuffle AQ matrix in block scale gemm (#2624)
* Preshuffle AQ matrix in block scale gemm * turns the output to fp16. Increase the repetition time. --------- Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -38,12 +38,9 @@ struct GemmAQuantPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Prob
|
||||
{
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
using YPerTile = number<MPerBlock>;
|
||||
using XPerTile = number<KPerBlockAQ>;
|
||||
|
||||
auto aq_copy_dram_window =
|
||||
make_tile_window(aq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile(), XPerTile()),
|
||||
aq_dram_block_window_tmp.get_window_lengths(),
|
||||
aq_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeAQDramTileDistribution<Problem>());
|
||||
return aq_copy_dram_window;
|
||||
|
||||
@@ -42,6 +42,7 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockAQ = KPerBlock / Problem::kQuantGroupSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
@@ -52,14 +53,34 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
false>;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize>;
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock / WarpGemm::kM,
|
||||
ck_tile::integer_least_multiple(
|
||||
WarpGemm::kM * KPerBlockAQ, get_warp_size()),
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
Preshuffle>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
WarpGemm,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlockAQ,
|
||||
KPerBlockAQ,
|
||||
VecLoadSize,
|
||||
Preshuffle>;
|
||||
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
@@ -134,6 +133,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr bool Preshuffle = Problem::Traits::Preshuffle;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
@@ -254,9 +254,6 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(!is_aq_col_major, "Aq must be row major (col major not supported yet)");
|
||||
static_assert(MPerBlock == AQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlockAQ == AQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
|
||||
"Aq block window has incorrect lengths for defined AqLayout!");
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
@@ -312,8 +309,11 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
// only row_major for AQ
|
||||
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
|
||||
Preshuffle ? make_array(MPerBlock / BlockGemm::WarpGemm::kM, 0)
|
||||
: make_array(0, KPerBlockAQ);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
@@ -50,10 +50,11 @@ template <typename BlockGemmShape,
|
||||
index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize>
|
||||
index_t KPerBlockAQ,
|
||||
index_t VecSize,
|
||||
bool Preshuffle>
|
||||
struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern
|
||||
{
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
@@ -69,26 +70,46 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter
|
||||
// KWarps > 1 isn't supported
|
||||
static_assert(KWarps == 1);
|
||||
|
||||
// # of elements per thread
|
||||
static constexpr index_t X = XPerTile;
|
||||
|
||||
static constexpr index_t Y0 = 1;
|
||||
static constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
static constexpr index_t Y2 = MWarps;
|
||||
static constexpr index_t Y3 = WarpGemm::kM;
|
||||
static_assert(Y3 >= WarpGemm::kM, "Scales for all rows must be available within the warp.");
|
||||
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
|
||||
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 1>>,
|
||||
tuple<sequence<2, 0>, sequence<0, 3>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
if constexpr(Preshuffle)
|
||||
{
|
||||
// # of elements per thread
|
||||
constexpr index_t X2 = KPerBlockAQ;
|
||||
constexpr index_t X1 = warp_size / X2;
|
||||
constexpr index_t X0 = XPerTile / warp_size;
|
||||
|
||||
constexpr index_t Y1 = MWarps;
|
||||
constexpr index_t Y0 = YPerTile / Y1;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1>, sequence<X0, X1, X2>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 2>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// # of elements per thread
|
||||
constexpr index_t X = XPerTile;
|
||||
|
||||
constexpr index_t Y0 = 1;
|
||||
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
constexpr index_t Y2 = MWarps;
|
||||
constexpr index_t Y3 = WarpGemm::kM;
|
||||
static_assert(Y3 >= WarpGemm::kM,
|
||||
"Scales for all rows must be available within the warp.");
|
||||
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
|
||||
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 1>>,
|
||||
tuple<sequence<2, 0>, sequence<0, 3>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ namespace ck_tile {
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool Preshuffle_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
@@ -29,6 +30,7 @@ struct TileGemmAQuantTraits
|
||||
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = Preshuffle_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user