From 3bb718ad5a9c77208228df904fecaa2a47073181 Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Wed, 6 Nov 2024 18:25:18 +0800 Subject: [PATCH] update pipeline_gemm0 --- .../core/arch/amd_buffer_addressing.hpp | 5 + include/ck_tile/core/arch/arch.hpp | 18 +++ .../fused_moegemm_pipeline_flatmm.hpp | 131 +++++++++++------- 3 files changed, 107 insertions(+), 47 deletions(-) diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 718f634170..fed8da77d5 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -640,6 +640,11 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); } +CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory"); +} + template struct buffer_atomic_add_if; diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 65a3a4e2ff..afcf982a63 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds() #endif } +CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0) +{ +#ifdef __gfx12__ + asm volatile("s_wait_loadcnt %0 \n" + "s_barrier_signal -1 \n" + "s_barrier_wait -1" + : + : "n"(cnt) + : "memory"); +#else + asm volatile("s_waitcnt vmcnt(%0) \n" + "s_barrier" + : + : "n"(cnt) + : "memory"); +#endif +} + CK_TILE_DEVICE void block_sync_lds_direct_load() { asm volatile("\ diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp index a3428fdabd..186e27d26b 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp @@ -260,9 +260,9 @@ struct FusedMoeGemmPipeline_Flatmm { async_load_tile_raw(a_store_, a_win, i_access, PreNop{}); }; - // auto move_a = [&]() { - // move_tile_window(a_win, {number<0>{}, number{}}); - // }; + auto move_a = [&]() { + move_tile_window(a_win, {number<0>{}, number{}}); + }; auto sld_a = [&](auto& a_, auto& win_, auto i_access) { load_tile_raw(a_, win_, i_access); }; @@ -284,11 +284,11 @@ struct FusedMoeGemmPipeline_Flatmm } load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); }; - // auto move_g = - // [&]() { - // move_tile_window(g_win, - // {number<0>{}, number{}, number<0>{}}); - // }; + auto move_g = + [&]() { + move_tile_window(g_win, + {number<0>{}, number{}, number<0>{}}); + }; statically_indexed_array ds; auto gld_d = [&]>( @@ -296,10 +296,10 @@ struct FusedMoeGemmPipeline_Flatmm { load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); }; - // auto move_d = [&]() { - // // d move along gemm-n - // move_tile_window(d_win, {number{}, number<0>{}}); - // }; + auto move_d = [&]() { + // d move along gemm-n + move_tile_window(d_win, {number{}, number<0>{}}); + }; auto atomic_add_o = [&]>( auto& o_, auto i_access, PreNop = {}) @@ -427,53 +427,66 @@ struct FusedMoeGemmPipeline_Flatmm // mfma(that can reuse the B matrix) only affected by M repeat. auto pipeline_gemm0 = [&]() { constexpr index_t total_loops = issues_gemm0; - constexpr index_t mfma_per_gld_g = total_loops / issues_g; // BlockShape::Repeat_M0; - constexpr index_t mfma_per_gld_a = total_loops / issues_a; - constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a; + constexpr index_t mfma_per_ld = total_loops / (issues_g + issues_a + issues_sld_a); // compute buffer 0 static_for<0, total_loops, 1>{}([&](auto i_issue) { gemm_0(acc_0, as[I0], gs[I0], i_issue); - if constexpr(i_issue % mfma_per_gld_g == 0) + + if constexpr(i_issue % mfma_per_ld == 0) { - gld_g(gs[I1], number{}); - move_g(); + constexpr index_t ld_id = 0; + + if constexpr(ld_id < issues_g) + { + gld_g(gs[I0], number{}); + } + if constexpr(ld_id - issues_g < +issues_a) + { + gld_a(a_sst_win0, number{}); + } + if constexpr(ld_id - issues_g - issues_a < issues_sld_a) + { + sld_a(as[I1], a_sld_win1, number{}); + } + + ld_id++; } - if constexpr(i_issue % mfma_per_gld_a == 0) - { - gld_a(a_sst_win0, number{}); - move_a(); - } - - if constexpr(i_issue % mfma_per_sld_a == 0) - { - block_sync_lds(); - sld_a(as[I1], a_sld_win1, number{}); - } }); + move_g(); + move_a(); + block_sync_load_raw(issues_a + issues_g); + lds_load_fence(); // compute buffer 1 static_for<0, total_loops, 1>{}([&](auto i_issue) { gemm_0(acc_0, as[I1], gs[I1], i_issue); - if constexpr(i_issue % mfma_per_gld_g == 0) - { - gld_g(gs[I0], number{}); - move_g(); - } - if constexpr(i_issue % mfma_per_gld_a == 0) + if constexpr(i_issue % mfma_per_ld == 0) { - gld_a(a_sst_win1, number{}); - move_a(); - } + constexpr index_t ld_id = 0; - if constexpr(i_issue % mfma_per_sld_a == 0) - { - block_sync_lds(); - sld_a(as[I0], a_sld_win0, number{}); + if constexpr(ld_id < issues_g) + { + gld_g(gs[I1], number{}); + } + if constexpr(ld_id - issues_g < +issues_a) + { + gld_a(a_sst_win1, number{}); + } + if constexpr(ld_id - issues_g - issues_a < issues_sld_a) + { + sld_a(as[I0], a_sld_win0, number{}); + } + + ld_id++; } }); + move_g(); + move_a(); + block_sync_load_raw(issues_a + issues_g); + lds_load_fence(); }; auto pipeline_gemm0_tail = [&]() { @@ -486,14 +499,23 @@ struct FusedMoeGemmPipeline_Flatmm static_for<0, total_loops, 1>{}([&](auto i_issue) { gemm_0(acc_0, as[I0], gs[I0], i_issue); if constexpr(i_issue % mfma_per_gld_g == 0) + { gld_g(gs[I1], number{}); + move_g(); + } // if constexpr (i_issue % mfma_per_gld_a == 0) // gld_a(a_sst_win0, number{}); - if constexpr(i_issue % mfma_per_sld_a == 0) - sld_a(as[I1], a_sld_win1, number{}); + // if constexpr(i_issue % mfma_per_sld_a == 0) + // { + // block_sync_load_raw(a_sst_win0.get_num_of_access()); + // sld_a(as[I1], a_sld_win1, number{}); + // } }); + // if cycle_mfma>gld_a sync here + block_sync_load_raw(issues_g); + sld_a(as[I1], a_sld_win1, NEG1{}); // compute buffer 1 static_for<0, total_loops, 1>{}([&](auto i_issue) { @@ -523,7 +545,10 @@ struct FusedMoeGemmPipeline_Flatmm static_for<0, total_loops, 1>{}([&](auto i_issue) { gemm_1(acc_1s[I1], y, ds[I1], i_issue); if constexpr(i_issue % mfma_per_gld_d == 0) + { gld_d(ds[I0], number{}); + move_d(); + } if constexpr(i_issue % mfma_per_atm_o == 0) { @@ -536,7 +561,10 @@ struct FusedMoeGemmPipeline_Flatmm static_for<0, total_loops, 1>{}([&](auto i_issue) { gemm_1(acc_1s[I0], y, ds[I0], i_issue); if constexpr(i_issue % mfma_per_gld_d == 0) + { gld_d(ds[I1], number{}); + move_d(); + } if constexpr(i_issue % mfma_per_atm_o == 0) { @@ -553,7 +581,10 @@ struct FusedMoeGemmPipeline_Flatmm static_for<0, total_loops, 1>{}([&](auto i_issue) { gemm_1(acc_1s[I0], y, ds[I0], i_issue); if constexpr(i_issue % mfma_per_gld_d == 0) + { gld_d(ds[I1], number{}); + move_d(); + } }); }; auto pipeline_gemm1_tail = [&]() { @@ -564,7 +595,10 @@ struct FusedMoeGemmPipeline_Flatmm static_for<0, total_loops, 1>{}([&](auto i_issue) { gemm_1(acc_1s[I1], y, ds[I1], i_issue); if constexpr(i_issue % mfma_per_gld_d == 0) + { gld_d(ds[I0], number{}); + move_d(); + } if constexpr(i_issue % mfma_per_atm_o == 0) { @@ -586,10 +620,13 @@ struct FusedMoeGemmPipeline_Flatmm move_g(); clear_tile(acc_0); - async_load_fence_raw(g_win.get_num_of_access()); - sld_a(as[I0], a_sld_win0, NEG1); - gld_a(a_sst_win1, NEG1); + // preload for next round + gld_a(a_sst_win1, NEG1); + gld_g(gs[I1], NEG1); + // make sure a,g loaded + block_sync_load_raw(issues_a + issues_g); + lds_load_fence(); // we manually unroll double buffer inside hot loop const index_t iters_0 = (num_blocks_k0 - 2) / 2;