mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
update gemm_mx_pipeline_ag_bg_cr_v3
This commit is contained in:
@@ -115,13 +115,6 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
|
||||
static constexpr index_t KPerBlockScale = BlockGemmShape::kK / BlockScaleSize;
|
||||
|
||||
using MXdlPack = PipelineImplBase::MXdlPack;
|
||||
using NXdlPack = PipelineImplBase::NXdlPack;
|
||||
using KXdlPack = PipelineImplBase::KXdlPack;
|
||||
@@ -129,6 +122,14 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
|
||||
using APackedSize = PipelineImplBase::APackedSize;
|
||||
using BPackedSize = PipelineImplBase::BPackedSize;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
|
||||
static constexpr index_t KPerBlockScale =
|
||||
BlockGemmShape::kK * APackedSize / (BlockScaleSize * KXdlPack);
|
||||
|
||||
/// Returns the value of Policy::GetVectorSizeA<Problem>() as an index_t.
/// The policy is the single source of this value; nothing is computed here.
static constexpr index_t GetVectorSizeA()
{
    const index_t vector_size_a = Policy::template GetVectorSizeA<Problem>();
    return vector_size_a;
}
|
||||
/// Returns the value of Policy::GetVectorSizeB<Problem>() as an index_t.
/// The policy is the single source of this value; nothing is computed here.
static constexpr index_t GetVectorSizeB()
{
    const index_t vector_size_b = Policy::template GetVectorSizeB<Problem>();
    return vector_size_b;
}
|
||||
/// Returns the value of Policy::GetVectorSizeC<Problem>() as an index_t.
/// The policy is the single source of this value; nothing is computed here.
static constexpr index_t GetVectorSizeC()
{
    const index_t vector_size_c = Policy::template GetVectorSizeC<Problem>();
    return vector_size_c;
}
|
||||
@@ -324,22 +325,26 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
|
||||
auto b_scale_copy_dram_window =
|
||||
Base::GetBScaleDramLoadWindow(b_scale_dram_block_window_tmp);
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using AScaleBlockTileDistr = decltype(a_scale_copy_dram_window.get_tile_distribution());
|
||||
using BScaleBlockTileDistr = decltype(b_scale_copy_dram_window.get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
using AQBlockTile =
|
||||
decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
|
||||
using AScaleBlockTile =
|
||||
decltype(make_static_distributed_tensor<AScaleDataType>(AScaleBlockTileDistr{}));
|
||||
using BScaleBlockTile =
|
||||
decltype(make_static_distributed_tensor<BScaleDataType>(BScaleBlockTileDistr{}));
|
||||
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
AQBlockTile aq_block_tile[2];
|
||||
AScaleBlockTile a_scale_block_tile[2];
|
||||
BScaleBlockTile b_scale_block_tile[2];
|
||||
int currIdx = 0;
|
||||
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
@@ -348,14 +353,22 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
|
||||
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
|
||||
constexpr AScaleDramTileWindowStep a_scale_dram_tile_window_step =
|
||||
is_a_scale_col_major ? make_array(KPerBlockScale, 0)
|
||||
: make_array(0, KPerBlockScale);
|
||||
constexpr BScaleBlockTileWindowStep b_scale_dram_tile_window_step =
|
||||
is_b_scale_row_major ? make_array(KPerBlockScale, 0)
|
||||
: make_array(0, KPerBlockScale);
|
||||
|
||||
// DRAM prefetch (global read 0)
|
||||
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(a_scale_block_tile[currIdx],
|
||||
a_scale_copy_dram_window,
|
||||
a_scale_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(b_scale_block_tile[currIdx],
|
||||
b_scale_copy_dram_window,
|
||||
b_scale_dram_tile_window_step);
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user