mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
fix flatmm syntax error on gfx950
This commit is contained in:
@@ -25,19 +25,23 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
|
||||
{
|
||||
if(TailNumber::Even == tail_num)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(TailNumber::Odd == tail_num)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
// assert(false);
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
// return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
// return run_func(bool_constant<true>{}, integral_constant<TailNumber,
|
||||
// TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -57,7 +61,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
|
||||
static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
static constexpr auto config =
|
||||
BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
@@ -545,9 +550,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler()
|
||||
{
|
||||
@@ -638,8 +643,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
auto a_lds_block_ping =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
// A DRAM tile window for load
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
@@ -673,27 +680,27 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
move_tile_window(a_copy_dram_window(AIter), {AIter * AcopyPerLoadM, 0});
|
||||
});
|
||||
|
||||
auto a_copy_lds_window_ping_tmp = make_tile_window(
|
||||
a_lds_block_ping,
|
||||
auto a_copy_lds_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<AcopyPerLoadM>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramDistribution<Problem>()
|
||||
);
|
||||
PipelinePolicy::template MakeADramDistribution<Problem>());
|
||||
|
||||
statically_indexed_array<decltype(a_copy_lds_window_ping_tmp), ACopyLoadNum> a_copy_lds_window_ping;
|
||||
statically_indexed_array<decltype(a_copy_lds_window_ping_tmp), ACopyLoadNum>
|
||||
a_copy_lds_window_ping;
|
||||
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
|
||||
a_copy_lds_window_ping(AIter) = a_copy_lds_window_ping_tmp;
|
||||
move_tile_window(a_copy_lds_window_ping(AIter), {AIter * AcopyPerLoadM, 0});
|
||||
});
|
||||
|
||||
auto a_copy_lds_window_pong_tmp = make_tile_window(
|
||||
a_lds_block_pong,
|
||||
auto a_copy_lds_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<AcopyPerLoadM>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramDistribution<Problem>()
|
||||
);
|
||||
PipelinePolicy::template MakeADramDistribution<Problem>());
|
||||
|
||||
statically_indexed_array<decltype(a_copy_lds_window_pong_tmp), ACopyLoadNum> a_copy_lds_window_pong;
|
||||
statically_indexed_array<decltype(a_copy_lds_window_pong_tmp), ACopyLoadNum>
|
||||
a_copy_lds_window_pong;
|
||||
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
|
||||
a_copy_lds_window_pong(AIter) = a_copy_lds_window_pong_tmp;
|
||||
move_tile_window(a_copy_lds_window_pong(AIter), {AIter * AcopyPerLoadM, 0});
|
||||
@@ -705,14 +712,14 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
// a_lds_block_ping, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp = make_tile_window(
|
||||
a_lds_block_ping,
|
||||
auto a_warp_window_ping_tmp =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
|
||||
auto a_warp_window_pong_tmp = make_tile_window(
|
||||
a_lds_block_pong,
|
||||
auto a_warp_window_pong_tmp =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
@@ -776,14 +783,14 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
|
||||
|
||||
// Prefetch A0
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
#else
|
||||
statically_indexed_array<decltype(load_tile(a_copy_dram_window(number<0>{}))), ACopyLoadNum> a_block_tile;
|
||||
statically_indexed_array<decltype(load_tile(a_copy_dram_window(number<0>{}))), ACopyLoadNum>
|
||||
a_block_tile;
|
||||
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
|
||||
a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter));
|
||||
move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock});
|
||||
@@ -815,14 +822,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
// store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func,
|
||||
// a_block_tile));
|
||||
// }
|
||||
#ifndef FINEGRADE_LOADSTORE
|
||||
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
|
||||
#else
|
||||
static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) {
|
||||
store_tile(a_copy_lds_window_ping(AIter), tile_elementwise_in(a_element_func, a_block_tile(AIter)));
|
||||
store_tile(a_copy_lds_window_ping(AIter),
|
||||
tile_elementwise_in(a_element_func, a_block_tile(AIter)));
|
||||
});
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -846,26 +855,34 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
|
||||
// preload A00,A10 from lds
|
||||
constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), m_preload> a_warp_tensor_ping;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))), m_preload> a_warp_tensor_pong;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
|
||||
m_preload>
|
||||
a_warp_tensor_ping;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))),
|
||||
m_preload>
|
||||
a_warp_tensor_pong;
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor_ping(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// if(threadIdx.x==0){
|
||||
// for(int i=0;i<a_block_tile.get_thread_buffer_size();i++) {
|
||||
// printf("dteng--A buffer load: idx.x=%u, ablocktile=%f, buffer size=%d\n", threadIdx.x, type_convert<float>(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size());
|
||||
// printf("dteng--A buffer load: idx.x=%u, ablocktile=%f, buffer size=%d\n",
|
||||
// threadIdx.x,
|
||||
// type_convert<float>(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size());
|
||||
// }
|
||||
// }
|
||||
// for(int i=0;i<a_warp_tensor_ping(number<0>{}).get_thread_buffer_size();i++) {
|
||||
// printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n", threadIdx.x, type_convert<float>(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size());
|
||||
// printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n",
|
||||
// threadIdx.x,
|
||||
// type_convert<float>(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size());
|
||||
// }
|
||||
|
||||
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
// if constexpr(HasMainLoop)
|
||||
// {
|
||||
@@ -907,7 +924,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_ping(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -918,37 +937,52 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
#ifdef FINEGRADE_LOADSTORE
|
||||
// prefetch B(2i+1)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) &&
|
||||
((curMNIter % BloadGap) == 1))
|
||||
{
|
||||
constexpr auto BnIter = curMNIter / BloadGap;
|
||||
constexpr auto BkIter = kIter;
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) =
|
||||
load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
}
|
||||
// Prefill A(2i+1)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) &&
|
||||
(mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
constexpr auto AIter =
|
||||
(mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) %
|
||||
ACopyLoadNum;
|
||||
store_tile(
|
||||
a_copy_lds_window_pong(number<AIter>{}),
|
||||
tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
}
|
||||
// Prefetch A(2i+2)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2)))
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) &&
|
||||
(mIter < (MIterPerWarp - 1 + 1)) &&
|
||||
((nIter % NIterPerWarp) == (NIterPerWarp - 2)))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
a_block_tile(number<AIter>{}) = load_tile(a_copy_dram_window(number<AIter>{}));
|
||||
constexpr auto AIter =
|
||||
(mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) %
|
||||
ACopyLoadNum;
|
||||
a_block_tile(number<AIter>{}) =
|
||||
load_tile(a_copy_dram_window(number<AIter>{}));
|
||||
move_tile_window(a_copy_dram_window(number<AIter>{}), {0, kKPerBlock});
|
||||
}
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
@@ -966,7 +1000,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor_pong(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -1007,7 +1042,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_pong(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_pong(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1018,37 +1055,52 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
#ifdef FINEGRADE_LOADSTORE
|
||||
// prefetch B(2i+2)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) &&
|
||||
((curMNIter % BloadGap) == 1))
|
||||
{
|
||||
constexpr auto BnIter = curMNIter / BloadGap;
|
||||
constexpr auto BkIter = kIter;
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_ping(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
b_warp_tensor_ping(number<BnIter>{})(BkIter) =
|
||||
load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
}
|
||||
// Prefill A(2i+1)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) &&
|
||||
(mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
store_tile(a_copy_lds_window_ping(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
constexpr auto AIter =
|
||||
(mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) %
|
||||
ACopyLoadNum;
|
||||
store_tile(
|
||||
a_copy_lds_window_ping(number<AIter>{}),
|
||||
tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
}
|
||||
// Prefetch A(2i+2)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2)))
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) &&
|
||||
(mIter < (MIterPerWarp - 1 + 1)) &&
|
||||
((nIter % NIterPerWarp) == (NIterPerWarp - 2)))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
a_block_tile(number<AIter>{}) = load_tile(a_copy_dram_window(number<AIter>{}));
|
||||
constexpr auto AIter =
|
||||
(mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) %
|
||||
ACopyLoadNum;
|
||||
a_block_tile(number<AIter>{}) =
|
||||
load_tile(a_copy_dram_window(number<AIter>{}));
|
||||
move_tile_window(a_copy_dram_window(number<AIter>{}), {0, kKPerBlock});
|
||||
}
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
@@ -1066,7 +1118,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor_ping(loadIter) =
|
||||
load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -1108,7 +1161,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_ping(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1119,30 +1174,40 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
#ifdef FINEGRADE_LOADSTORE
|
||||
// prefetch B(loopK)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) &&
|
||||
((curMNIter % BloadGap) == 1))
|
||||
{
|
||||
constexpr auto BnIter = curMNIter / BloadGap;
|
||||
constexpr auto BkIter = kIter;
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) =
|
||||
load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
}
|
||||
// Prefill A(loopK)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) &&
|
||||
(mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
constexpr auto AIter =
|
||||
(mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) %
|
||||
ACopyLoadNum;
|
||||
store_tile(
|
||||
a_copy_lds_window_pong(number<AIter>{}),
|
||||
tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
}
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
@@ -1159,7 +1224,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor_pong(loadIter) =
|
||||
load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
@@ -1177,7 +1243,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_pong(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_pong(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1186,11 +1254,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1214,7 +1284,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor_ping(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
@@ -1225,30 +1297,40 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
#ifdef FINEGRADE_LOADSTORE
|
||||
// prefetch B(loopK)
|
||||
constexpr auto curMNIter = mIter * NIterPerWarp + nIter;
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1))
|
||||
if constexpr((curMNIter < NIterPerWarp * BloadGap) &&
|
||||
((curMNIter % BloadGap) == 1))
|
||||
{
|
||||
constexpr auto BnIter = curMNIter / BloadGap;
|
||||
constexpr auto BkIter = kIter;
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(number<BnIter>{})(BkIter),
|
||||
{BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter});
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) = load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
b_warp_tensor_pong(number<BnIter>{})(BkIter) =
|
||||
load_tile(b_flat_dram_windows(number<BnIter>{})(BkIter));
|
||||
}
|
||||
// Prefill A(loopK)
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0))
|
||||
if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) &&
|
||||
(mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0))
|
||||
{
|
||||
constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum;
|
||||
store_tile(a_copy_lds_window_pong(number<AIter>{}), tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
constexpr auto AIter =
|
||||
(mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) %
|
||||
ACopyLoadNum;
|
||||
store_tile(
|
||||
a_copy_lds_window_pong(number<AIter>{}),
|
||||
tile_elementwise_in(a_element_func, a_block_tile(number<AIter>{})));
|
||||
}
|
||||
#endif
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
|
||||
Reference in New Issue
Block a user