mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
update pipeline_gemm0
This commit is contained in:
@@ -640,6 +640,11 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
template <typename scalar_type, index_t N, bool pre_nop = false>
|
||||
struct buffer_atomic_add_if;
|
||||
|
||||
|
||||
@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
asm volatile("s_wait_loadcnt %0 \n"
|
||||
"s_barrier_signal -1 \n"
|
||||
"s_barrier_wait -1"
|
||||
:
|
||||
: "n"(cnt)
|
||||
: "memory");
|
||||
#else
|
||||
asm volatile("s_waitcnt vmcnt(%0) \n"
|
||||
"s_barrier"
|
||||
:
|
||||
: "n"(cnt)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
asm volatile("\
|
||||
|
||||
@@ -260,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
{
|
||||
async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
|
||||
};
|
||||
// auto move_a = [&]() {
|
||||
// move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
|
||||
// };
|
||||
auto move_a = [&]() {
|
||||
move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
|
||||
};
|
||||
auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
|
||||
load_tile_raw(a_, win_, i_access);
|
||||
};
|
||||
@@ -284,11 +284,11 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
}
|
||||
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
|
||||
};
|
||||
// auto move_g =
|
||||
// [&]() {
|
||||
// move_tile_window(g_win,
|
||||
// {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
|
||||
// };
|
||||
auto move_g =
|
||||
[&]() {
|
||||
move_tile_window(g_win,
|
||||
{number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
|
||||
};
|
||||
statically_indexed_array<d_thread_type, 2> ds;
|
||||
|
||||
auto gld_d = [&]<typename PreNop = bool_constant<false>>(
|
||||
@@ -296,10 +296,10 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
{
|
||||
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
|
||||
};
|
||||
// auto move_d = [&]() {
|
||||
// // d move along gemm-n
|
||||
// move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
|
||||
// };
|
||||
auto move_d = [&]() {
|
||||
// d move along gemm-n
|
||||
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
|
||||
};
|
||||
|
||||
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& o_, auto i_access, PreNop = {})
|
||||
@@ -427,53 +427,66 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
// mfma(that can reuse the B matrix) only affected by M repeat.
|
||||
auto pipeline_gemm0 = [&]() {
|
||||
constexpr index_t total_loops = issues_gemm0;
|
||||
constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0;
|
||||
constexpr index_t mfma_per_gld_a = total_loops / issues_a;
|
||||
constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
|
||||
constexpr index_t mfma_per_ld = total_loops / (issues_g + issues_a + issues_sld_a);
|
||||
|
||||
// compute buffer 0
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_0(acc_0, as[I0], gs[I0], i_issue);
|
||||
if constexpr(i_issue % mfma_per_gld_g == 0)
|
||||
|
||||
if constexpr(i_issue % mfma_per_ld == 0)
|
||||
{
|
||||
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
|
||||
move_g();
|
||||
constexpr index_t ld_id = 0;
|
||||
|
||||
if constexpr(ld_id < issues_g)
|
||||
{
|
||||
gld_g(gs[I0], number<ld_id>{});
|
||||
}
|
||||
if constexpr(ld_id - issues_g < +issues_a)
|
||||
{
|
||||
gld_a(a_sst_win0, number<ld_id - issues_g>{});
|
||||
}
|
||||
if constexpr(ld_id - issues_g - issues_a < issues_sld_a)
|
||||
{
|
||||
sld_a(as[I1], a_sld_win1, number<ld_id - issues_g - issues_a>{});
|
||||
}
|
||||
|
||||
ld_id++;
|
||||
}
|
||||
|
||||
if constexpr(i_issue % mfma_per_gld_a == 0)
|
||||
{
|
||||
gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
|
||||
move_a();
|
||||
}
|
||||
|
||||
if constexpr(i_issue % mfma_per_sld_a == 0)
|
||||
{
|
||||
block_sync_lds();
|
||||
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
|
||||
}
|
||||
});
|
||||
move_g();
|
||||
move_a();
|
||||
block_sync_load_raw(issues_a + issues_g);
|
||||
lds_load_fence();
|
||||
|
||||
// compute buffer 1
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_0(acc_0, as[I1], gs[I1], i_issue);
|
||||
if constexpr(i_issue % mfma_per_gld_g == 0)
|
||||
{
|
||||
gld_g(gs[I0], number<i_issue / mfma_per_gld_g>{});
|
||||
move_g();
|
||||
}
|
||||
|
||||
if constexpr(i_issue % mfma_per_gld_a == 0)
|
||||
if constexpr(i_issue % mfma_per_ld == 0)
|
||||
{
|
||||
gld_a(a_sst_win1, number<i_issue / mfma_per_gld_a>{});
|
||||
move_a();
|
||||
}
|
||||
constexpr index_t ld_id = 0;
|
||||
|
||||
if constexpr(i_issue % mfma_per_sld_a == 0)
|
||||
{
|
||||
block_sync_lds();
|
||||
sld_a(as[I0], a_sld_win0, number<i_issue / mfma_per_sld_a>{});
|
||||
if constexpr(ld_id < issues_g)
|
||||
{
|
||||
gld_g(gs[I1], number<ld_id>{});
|
||||
}
|
||||
if constexpr(ld_id - issues_g < +issues_a)
|
||||
{
|
||||
gld_a(a_sst_win1, number<ld_id - issues_g>{});
|
||||
}
|
||||
if constexpr(ld_id - issues_g - issues_a < issues_sld_a)
|
||||
{
|
||||
sld_a(as[I0], a_sld_win0, number<ld_id - issues_g - issues_a>{});
|
||||
}
|
||||
|
||||
ld_id++;
|
||||
}
|
||||
});
|
||||
move_g();
|
||||
move_a();
|
||||
block_sync_load_raw(issues_a + issues_g);
|
||||
lds_load_fence();
|
||||
};
|
||||
|
||||
auto pipeline_gemm0_tail = [&]() {
|
||||
@@ -486,14 +499,23 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_0(acc_0, as[I0], gs[I0], i_issue);
|
||||
if constexpr(i_issue % mfma_per_gld_g == 0)
|
||||
{
|
||||
gld_g(gs[I1], number<i_issue / mfma_per_gld_g>{});
|
||||
move_g();
|
||||
}
|
||||
|
||||
// if constexpr (i_issue % mfma_per_gld_a == 0)
|
||||
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
|
||||
|
||||
if constexpr(i_issue % mfma_per_sld_a == 0)
|
||||
sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
|
||||
// if constexpr(i_issue % mfma_per_sld_a == 0)
|
||||
// {
|
||||
// block_sync_load_raw(a_sst_win0.get_num_of_access());
|
||||
// sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
|
||||
// }
|
||||
});
|
||||
// if cycle_mfma>gld_a sync here
|
||||
block_sync_load_raw(issues_g);
|
||||
sld_a(as[I1], a_sld_win1, NEG1{});
|
||||
|
||||
// compute buffer 1
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
@@ -523,7 +545,10 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
|
||||
if constexpr(i_issue % mfma_per_gld_d == 0)
|
||||
{
|
||||
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
|
||||
move_d();
|
||||
}
|
||||
|
||||
if constexpr(i_issue % mfma_per_atm_o == 0)
|
||||
{
|
||||
@@ -536,7 +561,10 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
|
||||
if constexpr(i_issue % mfma_per_gld_d == 0)
|
||||
{
|
||||
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
|
||||
move_d();
|
||||
}
|
||||
|
||||
if constexpr(i_issue % mfma_per_atm_o == 0)
|
||||
{
|
||||
@@ -553,7 +581,10 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
|
||||
if constexpr(i_issue % mfma_per_gld_d == 0)
|
||||
{
|
||||
gld_d(ds[I1], number<i_issue / mfma_per_gld_d>{});
|
||||
move_d();
|
||||
}
|
||||
});
|
||||
};
|
||||
auto pipeline_gemm1_tail = [&]() {
|
||||
@@ -564,7 +595,10 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
|
||||
if constexpr(i_issue % mfma_per_gld_d == 0)
|
||||
{
|
||||
gld_d(ds[I0], number<i_issue / mfma_per_gld_d>{});
|
||||
move_d();
|
||||
}
|
||||
|
||||
if constexpr(i_issue % mfma_per_atm_o == 0)
|
||||
{
|
||||
@@ -586,10 +620,13 @@ struct FusedMoeGemmPipeline_Flatmm
|
||||
move_g();
|
||||
clear_tile(acc_0);
|
||||
|
||||
async_load_fence_raw(g_win.get_num_of_access());
|
||||
sld_a(as[I0], a_sld_win0, NEG1);
|
||||
gld_a(a_sst_win1, NEG1);
|
||||
// preload for next round
|
||||
gld_a(a_sst_win1, NEG1);
|
||||
gld_g(gs[I1], NEG1);
|
||||
|
||||
// make sure a,g loaded
|
||||
block_sync_load_raw(issues_a + issues_g);
|
||||
lds_load_fence();
|
||||
|
||||
// we manually unroll double buffer inside hot loop
|
||||
const index_t iters_0 = (num_blocks_k0 - 2) / 2;
|
||||
|
||||
Reference in New Issue
Block a user