Flatmm merge (#2168)

* Sync with the CShuffleEpilogue function interface; fix the Flatmm build failure

* Move code from solin/flatmm, which adds an MFMA 16x16x32 FP8 path and optimizes the Flatmm hot loop

---------

Co-authored-by: solin <bingzhou@amd.com>
BingYuan.Zhou authored on 2025-05-08 12:59:57 +08:00, committed by GitHub
parent c7b8e86e34
commit 6a3960c1e1
11 changed files with 552 additions and 192 deletions


@@ -73,6 +73,83 @@ struct FlatmmPipelineAGmemBGmemCRegV1
return PipelinePolicy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
// constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num;
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
});
#elif defined(USING_MFMA_32x32x16)
static_for<0,
A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
#endif
}
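// Note: __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) is an LLVM
// AMDGPU scheduling hint: it asks the backend to place `size` instructions
// matching `mask` at this point, ordered against other groups with the same
// `sync_id`. The masks used above are 0x008 (MFMA), 0x020 (VMEM read),
// 0x100 (DS/LDS read) and 0x200 (DS/LDS write).
//
// Worked example of the instruction counts, under assumed tile sizes (not taken
// from this commit): kMPerBlock = kNPerBlock = 128, kKPerBlock = 64,
// BlockSize = 256, FP8 A data (1 byte) with 16-byte vector loads, a 16x16x32
// warp GEMM on a 2x2 (MWarp x NWarp) grid:
//
//   KPerLoad               = 16 / 1        = 16
//   A_Buffer_Load_Inst_Num = 128*64/256/16 = 2
//   KIterPerWarp           = 64 / 32       = 2
//   MIterPerWarp           = 128 / (2*16)  = 4
//   NIterPerWarp           = 128 / (2*16)  = 4
//   A_LDS_Read_Inst_Num    = 4 * 2         = 8
//   B_Buffer_Load_Inst_Num = 4 * 2         = 8
//
// The groups above then interleave those 2 A loads, 8 LDS reads, 8 B loads and
// 2 LDS writes into the MFMA stream instead of letting the scheduler cluster them.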
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
@@ -89,6 +166,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
@@ -112,6 +208,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_warp_window_tmp = make_tile_window(
a_lds_gemm_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
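// Note: the (MIterPerWarp x KIterPerWarp) A warp windows are materialized and
// pre-shifted to their tile offsets once here, so the hot loop consumes a
// statically indexed array of ready-made windows instead of recomputing window
// origins on every iteration.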
// Block GEMM
auto block_flatmm = BlockFlatmm();
@@ -126,16 +241,45 @@ struct FlatmmPipelineAGmemBGmemCRegV1
b_flat_distribution);
// Acc register tile
-auto c_block_tile = decltype(block_flatmm(a_lds_gemm_window, b_flat_dram_window)){};
+auto c_block_tile = block_flatmm.MakeCBlockTile();
// prefetch
// global read 0
auto a_block_tile = load_tile(a_copy_dram_window);
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_2;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
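// Note: all (NIterPerWarp x KIterPerWarp) flat-B warp tiles of the first
// K-block are prefetched into registers here; b_warp_tensor and b_warp_tensor_2
// then act as a ping-pong pair, so the next K-block's B loads overlap the
// current K-block's MFMAs.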
{
// move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// move to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -152,40 +296,116 @@ struct FlatmmPipelineAGmemBGmemCRegV1
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
block_sync_lds();
}
-index_t iCounter = num_loop - 1;
+index_t iCounter = num_loop / 2 - 1;
while(iCounter > 0)
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
block_sync_lds();
// GEMM i
-block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
+block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
block_sync_lds();
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// move to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// LDS write i + 1
-const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
HotLoopScheduler();
block_sync_lds();
// iCounter--;
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
// GEMM i
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
block_sync_lds();
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// move to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// LDS write i + 1
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
HotLoopScheduler();
block_sync_lds();
iCounter--;
}
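// Note: the hot loop above is unrolled by a factor of two, alternating between
// b_warp_tensor and b_warp_tensor_2 so each half overlaps its global loads with
// the other half's MFMAs. Together with the prologue load and the two-stage
// tail below, 2 * (num_loop / 2 - 1) + 2 = num_loop GEMM stages are issued,
// which assumes num_loop is even.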
// tail
{
// global read i + 1
a_block_tile = load_tile(a_copy_dram_window);
// GEMM i
block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
block_sync_lds();
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// move to i + 2
// move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// move to next flat K
// move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
HotLoopScheduler();
block_sync_lds();
// GEMM num_loop - 1
-block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
+block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
}
return c_block_tile;


@@ -19,23 +19,100 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
-constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
-constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
/* fewer transform layers compared with the old CK layout */
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
-make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
-make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
-number<8>{},
+make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
+make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
+number<KPack>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
#elif defined(USING_MFMA_32x32x16)
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(kMPerBlock),
-make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
+make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
#endif
/*xor*/
#if 0
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPack = GetSmemPackA<Problem>();
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
number<kMPerBlock / MLdsLayer>{},
number<kKPack>{}),
make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
number<kKPerBlock / kKPack * MLdsLayer>{})),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform(
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif
return a_lds_block_desc;
}
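// Note: on the 16x16x32 FP8 path the descriptor above XOR-permutes the K-chunk
// index against the row index, so different rows of the A tile land in
// different LDS bank groups. A minimal standalone sketch of the idea (assuming
// the permuted chunk index is (k0 ^ m) % num_chunks; the exact CK transform may
// differ in detail):
constexpr int xor_swizzled_lds_offset(int m, int k, int K, int KPack)
{
    const int num_chunks = K / KPack;              // the K0 dimension
    const int k0         = k / KPack;              // which KPack-wide chunk
    const int k1         = k % KPack;              // element inside the chunk
    const int k0_swz     = (k0 ^ m) % num_chunks;  // XOR swizzle of the chunk
    // naive-descriptor strides (KPack, K, 1) over dims (K0, M, K1), as above
    return k0_swz * KPack + m * K + k1;
}
// E.g. with K = 64 and KPack = 16 (FP8), rows 0..3 place their k0 = 0 chunk in
// four different 16-element slots, so a column read no longer hits one bank.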
@@ -58,7 +135,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
-return Problem::VectorLoadSize;
+return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
}
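// Note: GetSmemPackA now returns an element count rather than a byte count,
// which is what the number<KPack>{} dimensions in the LDS descriptors expect.
// E.g. with a 16-byte VectorLoadSize, FP8 A gives 16 / 1 = 16 elements, while
// fp16 would give 16 / 2 = 8.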
template <typename Problem>
@@ -82,7 +159,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t KPack = GetSmemPackA<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
-if constexpr(get_warp_size() % (K2 * M0))
+if constexpr(get_warp_size() >= (K2 * M0))
{
constexpr index_t K1 = get_warp_size() / (K2 * M0);
constexpr index_t K0 = BlockSize / get_warp_size();
@@ -209,7 +286,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t warp_size = get_warp_size();
-if constexpr(warp_size % (K2 * M0) == 0)
+if constexpr(warp_size >= (K2 * M0))
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
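// Note: the corrected guard checks that one wave covers at least one
// (K2 x M0) tile before splitting off K1 lanes. Worked example with assumed
// values: warp_size = 64, M0 = 8, K2 = 4 gives K2 * M0 = 32 <= 64, so
// K1 = 64 / 32 = 2 and K0 = kBlockSize / 64 waves sit on top. The old `%` test
// was inverted at the first call site (a nonzero remainder made it true, so
// evenly divisible cases were rejected) and required exact divisibility at the
// second; neither expresses "fits within a wave".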