mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
1.fix loop_num=odd bug 2.optimize mi300 performance of big MNK(tilesize 128x128x128) 3.optimize decode perf on mi300
This commit is contained in:
@@ -145,8 +145,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
// if(has_hot_loop)
|
||||
// {
|
||||
if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
@@ -165,28 +165,28 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// if(tail_num == ck_tile::TailNumber::Odd)
|
||||
// {
|
||||
// RunSplitk(ck_tile::bool_constant<false>{},
|
||||
// ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
// }
|
||||
// else if(tail_num == ck_tile::TailNumber::Even)
|
||||
// {
|
||||
// RunSplitk(ck_tile::bool_constant<false>{},
|
||||
// ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// std::ostringstream err;
|
||||
// err << "Num K loop must be larger than number of prefetech stages."
|
||||
// << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
// << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
// throw std::runtime_error(err.str());
|
||||
// }
|
||||
// }
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -161,9 +161,9 @@ struct GemmConfig
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
|
||||
@@ -96,6 +96,39 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
/*
|
||||
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
|
||||
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
|
||||
defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
|
||||
defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
|
||||
|
||||
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
|
||||
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
|
||||
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
|
||||
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
|
||||
|
||||
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
|
||||
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
*/
|
||||
|
||||
#if (defined(USING_MFMA_16x16x32_F8) || \
|
||||
defined(USING_MFMA_32x32x16_F8) || \
|
||||
defined(USING_MFMA_16x16x16_F16) || \
|
||||
defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5
|
||||
static constexpr auto mfma_per_wg = 2;
|
||||
static constexpr auto dsread_per_wg = 1;
|
||||
#elif (defined(USING_MFMA_16x16x32_F16) || \
|
||||
defined(USING_MFMA_32x32x16_F16) || \
|
||||
defined(USING_MFMA_16x16x128_F4) || \
|
||||
defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1
|
||||
static constexpr auto mfma_per_wg = 1;
|
||||
static constexpr auto dsread_per_wg = 1;
|
||||
#elif (defined(USING_MFMA_16x16x128_F8) || \
|
||||
defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2
|
||||
static constexpr auto mfma_per_wg = 1;
|
||||
static constexpr auto dsread_per_wg = 2;
|
||||
#endif
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -191,7 +224,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// 0 M7N2: 63 - - 8 -
|
||||
// 0 M7N3: 64 4 - - -
|
||||
|
||||
#if 1
|
||||
#if 0 // MI350 FP8 16X16 128*256*256
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
@@ -237,7 +270,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
#if 0
|
||||
#if 0 // MI350 FP8 16X16
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 3, 1>{}([&](auto i) {
|
||||
@@ -273,6 +306,166 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
#if 0 // MI300 FP8 16X16 128*128*128
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
#if 0 // MI300 FP8 16X16 128*256*128
|
||||
static_for<0, 2, 1>{}([&](auto j) {
|
||||
ignore = j;
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
#if 0 //MI300 FP8 16X16 16*64*256
|
||||
static_for<0, 1, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -340,6 +533,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
@@ -565,11 +759,14 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
block_sync_lds();
|
||||
|
||||
// preload A00,A10 from lds
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), 2> a_warp_tensor_ping;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))), 2> a_warp_tensor_pong;
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto mIter) {
|
||||
a_warp_tensor_ping(mIter) = load_tile(a_warp_windows_ping(mIter)(number<0>{}));
|
||||
constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2: 1;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))), m_preload> a_warp_tensor_ping;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{}))), m_preload> a_warp_tensor_pong;
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -583,7 +780,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// }
|
||||
|
||||
|
||||
index_t iCounter = num_loop / 2 - 1;
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
// if constexpr(HasMainLoop)
|
||||
// {
|
||||
while(iCounter > 0)
|
||||
@@ -614,7 +811,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -661,15 +858,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter != KIterPerWarp - 1) || (mIter < (MIterPerWarp - 2)))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -680,8 +877,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto mIter) {
|
||||
a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{}));
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -713,7 +912,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -759,15 +958,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -778,8 +977,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto mIter) {
|
||||
a_warp_tensor_ping(mIter) = load_tile(a_warp_windows_ping(mIter)(number<0>{}));
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
@@ -811,7 +1012,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -819,7 +1020,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
@@ -851,15 +1052,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -869,8 +1070,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
|
||||
TailHotLoopScheduler();
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto mIter) {
|
||||
a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{}));
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MIterPerWarp;
|
||||
constexpr auto kIter = loadIter / MIterPerWarp;
|
||||
a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
});
|
||||
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
@@ -878,7 +1081,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -886,7 +1089,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_pong(number<AwarpIter>{}), b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
@@ -897,10 +1100,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_pong(number<AwarpIter>{}) = load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
});
|
||||
@@ -915,7 +1118,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = mIter % 2;
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
@@ -923,7 +1126,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor_ping(number<AwarpIter>{}), b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
@@ -955,15 +1158,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2)))
|
||||
if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + 2) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp);
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor_ping(number<AwarpIter>{}) = load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
//barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2)))
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user