update gemm_mx_pipeline_ag_bg_cr_v3

This commit is contained in:
mtgu0705
2025-08-13 09:24:24 -05:00
parent 1698930818
commit c2edc3c5b8

View File

@@ -115,13 +115,6 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
static constexpr index_t KPerBlockScale = BlockGemmShape::kK / BlockScaleSize;
using MXdlPack = PipelineImplBase::MXdlPack;
using NXdlPack = PipelineImplBase::NXdlPack;
using KXdlPack = PipelineImplBase::KXdlPack;
@@ -129,6 +122,14 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
using APackedSize = PipelineImplBase::APackedSize;
using BPackedSize = PipelineImplBase::BPackedSize;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t BlockScaleSize = Problem::kBlockScaleSize;
static constexpr index_t KPerBlockScale =
BlockGemmShape::kK * APackedSize / (BlockScaleSize * KXdlPack);
// Compile-time accessor: forwards to Policy::GetVectorSizeA instantiated with Problem
// and returns its result unchanged.
// NOTE(review): presumably the vector access width used for the A operand — confirm in Policy.
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
// Compile-time accessor: forwards to Policy::GetVectorSizeB instantiated with Problem
// and returns its result unchanged.
// NOTE(review): presumably the vector access width used for the B operand — confirm in Policy.
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
// Compile-time accessor: forwards to Policy::GetVectorSizeC instantiated with Problem
// and returns its result unchanged.
// NOTE(review): presumably the vector access width used for the C operand — confirm in Policy.
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
@@ -324,22 +325,26 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
auto b_scale_copy_dram_window =
Base::GetBScaleDramLoadWindow(b_scale_dram_block_window_tmp);
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using AQBlockTileDistr = decltype(aq_copy_dram_window.get_tile_distribution());
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using AScaleBlockTileDistr = decltype(a_scale_copy_dram_window.get_tile_distribution());
using BScaleBlockTileDistr = decltype(b_scale_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
using AQBlockTile =
decltype(make_static_distributed_tensor<AQDataType>(AQBlockTileDistr{}));
using AScaleBlockTile =
decltype(make_static_distributed_tensor<AScaleDataType>(AScaleBlockTileDistr{}));
using BScaleBlockTile =
decltype(make_static_distributed_tensor<BScaleDataType>(BScaleBlockTileDistr{}));
auto block_gemm = BlockGemm();
ABlockTile a_block_tile;
BBlockTile b_block_tile;
AQBlockTile aq_block_tile[2];
AScaleBlockTile a_scale_block_tile[2];
BScaleBlockTile b_scale_block_tile[2];
int currIdx = 0;
auto c_block_tile = block_gemm.MakeCBlockTile();
@@ -348,14 +353,22 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr AQDramTileWindowStep aq_dram_tile_window_step =
is_aq_col_major ? make_array(KPerBlockAQ, 0) : make_array(0, KPerBlockAQ);
constexpr AScaleDramTileWindowStep a_scale_dram_tile_window_step =
is_a_scale_col_major ? make_array(KPerBlockScale, 0)
: make_array(0, KPerBlockScale);
constexpr BScaleBlockTileWindowStep b_scale_dram_tile_window_step =
is_b_scale_row_major ? make_array(KPerBlockScale, 0)
: make_array(0, KPerBlockScale);
// DRAM prefetch (global read 0)
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(
aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step);
Base::GlobalPrefetch(a_scale_block_tile[currIdx],
a_scale_copy_dram_window,
a_scale_dram_tile_window_step);
Base::GlobalPrefetch(b_scale_block_tile[currIdx],
b_scale_copy_dram_window,
b_scale_dram_tile_window_step);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);