1.fix loop_num=odd bug 2.optimize mi300 performance of big MNK(tilesize 128x128x128) 3.optimize decode perf on mi300

This commit is contained in:
root
2025-06-25 20:41:25 -05:00
parent 40f1d5829e
commit 56f84349ca
3 changed files with 271 additions and 68 deletions

View File

@@ -145,8 +145,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
}
};
if(has_hot_loop)
{
// if(has_hot_loop)
// {
if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
@@ -165,28 +165,28 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
else
{
if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "Num K loop must be larger than number of prefetech stages."
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
// }
// else
// {
// if(tail_num == ck_tile::TailNumber::Odd)
// {
// RunSplitk(ck_tile::bool_constant<false>{},
// ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
// }
// else if(tail_num == ck_tile::TailNumber::Even)
// {
// RunSplitk(ck_tile::bool_constant<false>{},
// ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
// }
// else
// {
// std::ostringstream err;
// err << "Num K loop must be larger than number of prefetech stages."
// << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
// << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
// throw std::runtime_error(err.str());
// }
// }
return ave_time;
}

View File

@@ -161,9 +161,9 @@ struct GemmConfig
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;

View File

@@ -96,6 +96,39 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
/*
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
*/
#if (defined(USING_MFMA_16x16x32_F8) || \
defined(USING_MFMA_32x32x16_F8) || \
defined(USING_MFMA_16x16x16_F16) || \
defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5
static constexpr auto mfma_per_wg = 2;
static constexpr auto dsread_per_wg = 1;
#elif (defined(USING_MFMA_16x16x32_F16) || \
defined(USING_MFMA_32x32x16_F16) || \
defined(USING_MFMA_16x16x128_F4) || \
defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1
static constexpr auto mfma_per_wg = 1;
static constexpr auto dsread_per_wg = 1;
#elif (defined(USING_MFMA_16x16x128_F8) || \
defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2
static constexpr auto mfma_per_wg = 1;
static constexpr auto dsread_per_wg = 2;
#endif
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -191,7 +224,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// 0 M7N2: 63 - - 8 -
// 0 M7N3: 64 4 - - -
#if 1
#if 0 // MI350 FP8 16X16 128*256*256
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 3, 1>{}([&](auto i) {
@@ -237,7 +270,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0);
#endif
#if 0
#if 0 // MI350 FP8 16X16
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 3, 1>{}([&](auto i) {
@@ -273,6 +306,166 @@ struct FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0);
#endif
#if 0 // MI300 FP8 16X16 128*128*128
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
});
__builtin_amdgcn_sched_barrier(0);
#endif
#if 0 // MI300 FP8 16X16 128*256*128
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 4, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
});
__builtin_amdgcn_sched_barrier(0);
#endif
#if 0 //MI300 FP8 16X16 16*64*256
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
#endif
}
@@ -340,6 +533,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
const index_t iMWarp = get_warp_id() / NWarp;
using CWarpDstr = typename WG::CWarpDstr;
@@ -565,11 +759,14 @@ struct FlatmmPipelineAGmemBGmemCRegV1
block_sync_lds();
// preload A00,A10 from lds
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), 2> a_warp_tensor_ping;
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))), 2> a_warp_tensor_pong;
static_for<0, 2, 1>{}([&](auto mIter) {
a_warp_tensor_ping(mIter) = load_tile(a_warp_windows_ping(mIter)(number<0>{}));
constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2: 1;
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), m_preload> a_warp_tensor_ping;
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))), m_preload> a_warp_tensor_pong;
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
__builtin_amdgcn_sched_barrier(0);
@@ -583,7 +780,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// }
index_t iCounter = num_loop / 2 - 1;
index_t iCounter = (num_loop - 1) / 2;
// if constexpr(HasMainLoop)
// {
while(iCounter > 0)
@@ -614,7 +811,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = mIter % 2;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -661,15 +858,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter != KIterPerWarp - 1) || (mIter < (MIterPerWarp - 2)))
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
//barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
@@ -680,8 +877,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
static_for<0, 2, 1>{}([&](auto mIter) {
a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{}));
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
@@ -713,7 +912,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = mIter % 2;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -759,15 +958,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
//barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
@@ -778,8 +977,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
static_for<0, 2, 1>{}([&](auto mIter) {
a_warp_tensor_ping(mIter) = load_tile(a_warp_windows_ping(mIter)(number<0>{}));
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
});
HotLoopScheduler();
@@ -811,7 +1012,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = mIter % 2;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -819,7 +1020,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
@@ -851,15 +1052,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
//barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}
@@ -869,8 +1070,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
TailHotLoopScheduler();
static_for<0, 2, 1>{}([&](auto mIter) {
a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{}));
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MIterPerWarp;
constexpr auto kIter = loadIter / MIterPerWarp;
a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
});
// __builtin_amdgcn_sched_barrier(0);
@@ -878,7 +1081,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = mIter % 2;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -886,7 +1089,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_pong(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
@@ -897,10 +1100,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
}
});
@@ -915,7 +1118,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
constexpr auto AwarpIter = mIter % 2;
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -923,7 +1126,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
@@ -955,15 +1158,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
__builtin_amdgcn_sched_barrier(0x7F6);
});
// preload next A from lds
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
{
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
}
//barrier
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
{
block_sync_lds();
}