Fix flatmm syntax error on gfx950

This commit is contained in:
Feng Shijie
2025-07-23 19:12:31 +00:00
parent 5a1183ebbd
commit b908f5e803

View File

@@ -25,19 +25,23 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
/// @brief Dispatch a pipeline run functor with a compile-time tail tag.
///
/// Maps the runtime @p tail_num onto a compile-time
/// integral_constant<TailNumber, ...> so @p run_func can be instantiated once
/// per tail case. Only Even and Odd are expected here (GetBlockLoopTailNum
/// above returns exactly those); any other value falls through to Empty.
///
/// @param run_func invoked as run_func(bool_constant<true>{}, tail_tag)
/// @param (unnamed) bool flag ignored by this handler; kept so the signature
///        matches other pipeline TailHandler overloads
/// @param tail_num runtime tail-iteration selector
/// @return whatever @p run_func returns for the selected tag
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto
TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
{
    if(TailNumber::Even == tail_num)
    {
        return run_func(bool_constant<true>{},
                        integral_constant<TailNumber, TailNumber::Even>{});
    }
    else if(TailNumber::Odd == tail_num)
    {
        return run_func(bool_constant<true>{},
                        integral_constant<TailNumber, TailNumber::Odd>{});
    }
    // Unreachable for well-formed num_loop (see GetBlockLoopTailNum); kept as a
    // defined fallback instead of assert(false) so device code stays total.
    return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
@@ -57,7 +61,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
static constexpr auto config =
BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
@@ -545,9 +550,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
}
}
#endif
}
}
CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler()
{
@@ -638,8 +643,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
auto a_lds_block_ping =
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
#ifndef FINEGRADE_LOADSTORE
@@ -673,27 +680,27 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
move_tile_window(a_copy_dram_window(AIter), {AIter * AcopyPerLoadM, 0});
});
auto a_copy_lds_window_ping_tmp = make_tile_window(
a_lds_block_ping,
auto a_copy_lds_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<AcopyPerLoadM>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramDistribution<Problem>()
);
PipelinePolicy::template MakeADramDistribution<Problem>());
statically_indexed_array<decltype(a_copy_lds_window_ping_tmp), ACopyLoadNum> a_copy_lds_window_ping;
statically_indexed_array<decltype(a_copy_lds_window_ping_tmp), ACopyLoadNum>
a_copy_lds_window_ping;
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
a_copy_lds_window_ping(AIter) = a_copy_lds_window_ping_tmp;
move_tile_window(a_copy_lds_window_ping(AIter), {AIter * AcopyPerLoadM, 0});
});
auto a_copy_lds_window_pong_tmp = make_tile_window(
a_lds_block_pong,
auto a_copy_lds_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<AcopyPerLoadM>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramDistribution<Problem>()
);
PipelinePolicy::template MakeADramDistribution<Problem>());
statically_indexed_array<decltype(a_copy_lds_window_pong_tmp), ACopyLoadNum> a_copy_lds_window_pong;
statically_indexed_array<decltype(a_copy_lds_window_pong_tmp), ACopyLoadNum>
a_copy_lds_window_pong;
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
a_copy_lds_window_pong(AIter) = a_copy_lds_window_pong_tmp;
move_tile_window(a_copy_lds_window_pong(AIter), {AIter * AcopyPerLoadM, 0});
@@ -705,14 +712,14 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// a_lds_block_ping, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// ping-pong window for A LDS
auto a_warp_window_ping_tmp = make_tile_window(
a_lds_block_ping,
auto a_warp_window_ping_tmp =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
auto a_warp_window_pong_tmp = make_tile_window(
a_lds_block_pong,
auto a_warp_window_pong_tmp =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
@@ -776,14 +783,14 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
NIterPerWarp>
b_warp_tensor_pong;
// Prefetch A0
#ifndef FINEGRADE_LOADSTORE
auto a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
#else
statically_indexed_array<decltype(load_tile(a_copy_dram_window(number<0>{}))), ACopyLoadNum> a_block_tile;
statically_indexed_array<decltype(load_tile(a_copy_dram_window(number<0>{}))), ACopyLoadNum>
a_block_tile;
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter));
move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock});
@@ -815,14 +822,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// }
// else
// {
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile));
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func,
// a_block_tile));
// }
#ifndef FINEGRADE_LOADSTORE
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
#else
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
store_tile(a_copy_lds_window_ping(AIter), tile_elementwise_in(a_element_func, a_block_tile(AIter)));
store_tile(a_copy_lds_window_ping(AIter),
tile_elementwise_in(a_element_func, a_block_tile(AIter)));
});
#endif
__builtin_amdgcn_sched_barrier(0);
@@ -846,26 +855,34 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
// preload A00,A10 from lds
constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1;
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), m_preload> a_warp_tensor_ping;
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))), m_preload> a_warp_tensor_pong;
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor_ping;
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))),
m_preload>
a_warp_tensor_pong;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor_ping(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
// if(threadIdx.x==0){
// for(int i=0;i<a_block_tile.get_thread_buffer_size();i++) {
// printf("dteng--A buffer load: idx.x=%u, ablocktile=%f, buffer size=%d\n", threadIdx.x, type_convert<float>(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size());
// printf("dteng--A buffer load: idx.x=%u, ablocktile=%f, buffer size=%d\n",
// threadIdx.x,
// type_convert<float>(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size());
// }
// }
// for(int i=0;i<a_warp_tensor_ping(number<0>{}).get_thread_buffer_size();i++) {
// printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n", threadIdx.x, type_convert<float>(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size());
// printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n",
// threadIdx.x,
// type_convert<float>(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size());
// }
index_t iCounter = (num_loop - 1) / 2;
// if constexpr(HasMainLoop)
// {
@@ -907,7 +924,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -918,37 +937,52 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
#ifdef FINEGRADE_LOADSTORE
// prefetch B(2i+1)
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
if constexpr((curMNIter < NIterPerWarp * BloadGap) &&
((curMNIter % BloadGap) == 1))
{
constexpr auto BnIter = curMNIter / BloadGap;
constexpr auto BkIter = kIter;
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
move_tile_window(
b_flat_dram_windows(number<BnIter>{})(BkIter),
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
b_warp_tensor_pong(number<BnIter>{})(BkIter) =
load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
}
// Prefill A(2i+1)
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) &&
(mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0))
{
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
constexpr auto AIter =
(mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) %
ACopyLoadNum;
store_tile(
a_copy_lds_window_pong(number<AIter>{}),
tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
}
// Prefetch A(2i+2)
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2)))
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) &&
(mIter < (MIterPerWarp - 1 + 1)) &&
((nIter % NIterPerWarp) == (NIterPerWarp - 2)))
{
constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
a_block_tile(number<AIter>{}) = load_tile(a_copy_dram_window(number<AIter>{}));
constexpr auto AIter =
(mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) %
ACopyLoadNum;
a_block_tile(number<AIter>{}) =
load_tile(a_copy_dram_window(number<AIter>{}));
move_tile_window(a_copy_dram_window(number<AIter>{}), {0, kKPerBlock});
}
#endif
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor_ping(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
@@ -966,7 +1000,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor_pong(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
@@ -1007,7 +1042,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_pong(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor_pong(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1018,37 +1055,52 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
#ifdef FINEGRADE_LOADSTORE
// prefetch B(2i+2)
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
if constexpr((curMNIter < NIterPerWarp * BloadGap) &&
((curMNIter % BloadGap) == 1))
{
constexpr auto BnIter = curMNIter / BloadGap;
constexpr auto BkIter = kIter;
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
move_tile_window(
b_flat_dram_windows(number<BnIter>{})(BkIter),
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
b_warp_tensor_ping(number<BnIter>{})(BkIter) =
load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
}
// Prefill A(2i+1)
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) &&
(mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0))
{
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
store_tile(a_copy_lds_window_ping(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
constexpr auto AIter =
(mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) %
ACopyLoadNum;
store_tile(
a_copy_lds_window_ping(number<AIter>{}),
tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
}
// Prefetch A(2i+2)
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2)))
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) &&
(mIter < (MIterPerWarp - 1 + 1)) &&
((nIter % NIterPerWarp) == (NIterPerWarp - 2)))
{
constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
a_block_tile(number<AIter>{}) = load_tile(a_copy_dram_window(number<AIter>{}));
constexpr auto AIter =
(mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) %
ACopyLoadNum;
a_block_tile(number<AIter>{}) =
load_tile(a_copy_dram_window(number<AIter>{}));
move_tile_window(a_copy_dram_window(number<AIter>{}), {0, kKPerBlock});
}
#endif
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor_pong(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
// barrier
@@ -1066,7 +1118,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor_ping(loadIter) =
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
@@ -1108,7 +1161,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1119,30 +1174,40 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
#ifdef FINEGRADE_LOADSTORE
// prefetch B(loopK)
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
if constexpr((curMNIter < NIterPerWarp * BloadGap) &&
((curMNIter % BloadGap) == 1))
{
constexpr auto BnIter = curMNIter / BloadGap;
constexpr auto BkIter = kIter;
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
move_tile_window(
b_flat_dram_windows(number<BnIter>{})(BkIter),
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
b_warp_tensor_pong(number<BnIter>{})(BkIter) =
load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
}
// Prefill A(loopK)
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) &&
(mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0))
{
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
constexpr auto AIter =
(mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) %
ACopyLoadNum;
store_tile(
a_copy_lds_window_pong(number<AIter>{}),
tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
}
#endif
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor_ping(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier
@@ -1159,7 +1224,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor_pong(loadIter) =
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// __builtin_amdgcn_sched_barrier(0);
@@ -1177,7 +1243,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_pong(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor_pong(number<AwarpIter>{}),
b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1186,11 +1254,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor_pong(number<AwarpIter>{}) =
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
});
});
@@ -1214,7 +1284,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
WG{}(c_warp_tensor,
a_warp_tensor_ping(number<AwarpIter>{}),
b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
@@ -1225,30 +1297,40 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
#ifdef FINEGRADE_LOADSTORE
// prefetch B(loopK)
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
if constexpr((curMNIter < NIterPerWarp * BloadGap) &&
((curMNIter % BloadGap) == 1))
{
constexpr auto BnIter = curMNIter / BloadGap;
constexpr auto BkIter = kIter;
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
move_tile_window(
b_flat_dram_windows(number<BnIter>{})(BkIter),
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
b_warp_tensor_pong(number<BnIter>{})(BkIter) =
load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
}
// Prefill A(loopK)
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) &&
(mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0))
{
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
constexpr auto AIter =
(mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) %
ACopyLoadNum;
store_tile(
a_copy_lds_window_pong(number<AIter>{}),
tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
}
#endif
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
if constexpr((kIter * MIterPerWarp + mIter) <
(KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor_ping(number<AwarpIter>{}) =
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
// barrier