From d64030ed34b10fb9f7cdbda384bfb171ac1df72a Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 7 Feb 2025 02:50:26 +0000 Subject: [PATCH] revert blockwisegemm modification --- .../blockwise_gemm_pipeline_xdlops_v1.hpp | 339 +-------- .../blockwise_gemm_pipeline_xdlops_v2.hpp | 673 +----------------- .../blockwise_gemm_pipeline_xdlops_v3.hpp | 282 +------- .../blockwise_gemm_pipeline_xdlops_v4.hpp | 311 +------- .../blockwise_gemm_pipeline_xdlops_v5.hpp | 281 +------- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 357 +++------- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 217 ------ 7 files changed, 131 insertions(+), 2329 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index b0cebac09c..f597573dc2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -155,158 +155,6 @@ struct BlockwiseGemmXdlops_pipeline_v1 - __device__ void Run(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - CThreadBuffer& c_thread_buf, - AThreadBuffer& a_thread_buf_tail, - BThreadBuffer& b_thread_buf_tail, - index_t num_loop) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - do - { - // ------------------------------------------------------------------------------------------- - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - }); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - i += 1; - } while(i < (num_loop - 1)); - } - - // tail - if constexpr(TailNum == TailNumber::Full) - { - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - }); - - a_thread_buf_tail = a_thread_buf; - b_thread_buf_tail = b_thread_buf; - } - } - template - __device__ void Run(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - CThreadBuffer& c_thread_buf, - AThreadBuffer& a_thread_buf_tail, - BThreadBuffer& b_thread_buf_tail, - index_t num_loop) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - do - { - // ------------------------------------------------------------------------------------------- - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - __builtin_amdgcn_sched_barrier(0); - // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, - // but except the first, as we can shorten non-MAC cluster a bit and there's no - // observable negative impact. The desired effect is waves in a workgroup - // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC - // resource from other workgroups and reducing the chance of latency hiding by - // waiting for the rest of the workgroup at the eventual sync point. - if constexpr(k0.value != 0 || KRepeat == 1) - { - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); - } - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by - // applying small delays to different wavefronts It is performed - // near the end of MAC cluster to minimize lgkmcnt penalty - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } - }); - }); - }); - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(0); - __builtin_amdgcn_sched_barrier(0); - }); - - // block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - i += 1; - } while(i < (num_loop - 1)); - } - - // tail - if constexpr(TailNum == TailNumber::Full) - { - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - }); - - a_thread_buf_tail = a_thread_buf; - b_thread_buf_tail = b_thread_buf; - } - } - template M loopover static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( make_tuple(Number{}, I1, Number{}, Number{}), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 5d0f910614..0fe51d5003 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -199,281 +199,6 @@ struct BlockwiseGemmXdlops_pipeline_v2 - __device__ void Run(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - CThreadBuffer& c_thread_buf, - AThreadBuffer& a_thread_buf_tail, - BThreadBuffer& b_thread_buf_tail, - index_t num_loop) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); - - // Global prefetch [2, PrefetchStages] - static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - }); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - do - { - static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { - // ------------------------------------------------------------------------------------------- - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - a_blockwise_copy.RunWrite( - a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); - b_blockwise_copy.RunWrite( - b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); - - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - }); - - i += PrefetchStages; - } while(i < (num_loop - PrefetchStages)); - } - - // tail - - auto LoopTailFunc = [&](auto tail_num) { - static_for<1, tail_num, 1>{}([&](auto iprefetch) { - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); - }); - - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - - a_thread_buf_tail = a_thread_buf; - b_thread_buf_tail = b_thread_buf; - }; - - if constexpr(TailNum == TailNumber::One) - { - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); - }); - - a_thread_buf_tail = a_thread_buf; - b_thread_buf_tail = b_thread_buf; - } - else if constexpr(TailNum == TailNumber::Two) - { - LoopTailFunc(Number<2>{}); - } - else if constexpr(TailNum == TailNumber::Three) - { - LoopTailFunc(Number<3>{}); - } - else if constexpr(TailNum == TailNumber::Four) - { - LoopTailFunc(Number<4>{}); - } - else if constexpr(TailNum == TailNumber::Five) - { - LoopTailFunc(Number<5>{}); - } - else if constexpr(TailNum == TailNumber::Six) - { - LoopTailFunc(Number<6>{}); - } - else if constexpr(TailNum == TailNumber::Seven) - { - LoopTailFunc(Number<7>{}); - } - else if constexpr(TailNum == TailNumber::Full) - { - LoopTailFunc(Number{}); - } - } - template {}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); @@ -621,14 +345,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); @@ -676,14 +400,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); @@ -727,14 +451,14 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); - }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); }); }); @@ -796,6 +520,7 @@ struct BlockwiseGemmXdlops_pipeline_v2 - __device__ void Run(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - CThreadBuffer& c_thread_buf, - AThreadBuffer& a_thread_buf_tail, - BThreadBuffer& b_thread_buf_tail, - index_t num_loop) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); - - // Global prefetch [2, PrefetchStages] - static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - }); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - do - { - static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { - // ------------------------------------------------------------------------------------------- - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - __builtin_amdgcn_sched_barrier(0); - // NOTE: Synchronize threads in a workgroup at the start of each MAC - // cluster, but except the first, as we can shorten non-MAC cluster a bit - // and there's no observable negative impact. The desired effect is waves in - // a workgroup executing MAC in sync. This avoids some out-of-sync waves - // hijacking MAC resource from other workgroups and reducing the chance of - // latency hiding by waiting for the rest of the workgroup at the eventual - // sync point. - if constexpr(k0.value != 0 || KRepeat == 1) - { - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); - } - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard because barrier from - // blockwise_gemm is moved here B) reduce VMEM FIFO congestion - // by applying small delays to different wavefronts It is - // performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } - }); - }); - }); - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(0); - __builtin_amdgcn_sched_barrier(0); - }); - - // block_sync_lds(); - a_blockwise_copy.RunWrite( - a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); - b_blockwise_copy.RunWrite( - b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); - - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - }); - i += PrefetchStages; - } while(i < (num_loop - PrefetchStages)); - } - - // tail - - auto LoopTailFunc = [&](auto tail_num) { - static_for<1, tail_num, 1>{}([&](auto iprefetch) { - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - - __builtin_amdgcn_sched_barrier(0); - if constexpr(k0.value != 0 || KRepeat == 1) - { - __builtin_amdgcn_s_barrier(); - __builtin_amdgcn_sched_barrier(0); - } - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - if constexpr(k0.value == KRepeat - 1 && - k_.value == KPerInnerLoop - KPack && - m0.value == MRepeat - 1 && n0.value == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } - }); - }); - }); - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(0); - __builtin_amdgcn_sched_barrier(0); - }); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); - }); - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - }); - - a_thread_buf_tail = a_thread_buf; - b_thread_buf_tail = b_thread_buf; - }; - - if constexpr(TailNum == TailNumber::One) - { - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - }); - - a_thread_buf_tail = a_thread_buf; - b_thread_buf_tail = b_thread_buf; - } - else if constexpr(TailNum == TailNumber::Two) - { - LoopTailFunc(Number<2>{}); - } - else if constexpr(TailNum == TailNumber::Three) - { - LoopTailFunc(Number<3>{}); - } - else if constexpr(TailNum == TailNumber::Four) - { - LoopTailFunc(Number<4>{}); - } - else if constexpr(TailNum == TailNumber::Five) - { - LoopTailFunc(Number<5>{}); - } - else if constexpr(TailNum == TailNumber::Six) - { - LoopTailFunc(Number<6>{}); - } - else if constexpr(TailNum == TailNumber::Seven) - { - LoopTailFunc(Number<7>{}); - } - else if constexpr(TailNum == TailNumber::Full) - { - LoopTailFunc(Number{}); - } - } - template M loopover static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( make_tuple(Number{}, I1, Number{}, Number{}), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index e87616efe1..171a232c0f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -262,227 +262,6 @@ struct BlockwiseGemmXdlops_pipeline_v3 - __device__ void Run(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - CThreadBuffer& c_thread_buf, - AThreadBuffer& a_thread_buf_tail, - BThreadBuffer& b_thread_buf_tail, - index_t num_loop) const - { - __builtin_amdgcn_sched_barrier(0); - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - // Global prefetch 2 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - // Local prefetch 1 - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - - __builtin_amdgcn_sched_barrier(0); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - do - { - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - - HotLoopScheduler(); - __builtin_amdgcn_sched_barrier(0); - - i += 1; - } while(i < (num_loop - 2)); - } - // tail - if constexpr(TailNum == TailNumber::Full) - { - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf_tail); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf_tail); - }); - }); - - HotLoopScheduler(); - __builtin_amdgcn_sched_barrier(0); - } - } - template {}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf); - }); - }); - - HotLoopScheduler(); - __builtin_amdgcn_sched_barrier(0); - static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -731,6 +452,7 @@ struct BlockwiseGemmXdlops_pipeline_v3 - __device__ void Run(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - CThreadBuffer& c_thread_buf, - AThreadBuffer& a_thread_buf_tail, - BThreadBuffer& b_thread_buf_tail, - index_t num_loop) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - StaticallyIndexedArray{}> a_thread_bufs; - StaticallyIndexedArray{}> b_thread_bufs; - - // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); - - // Local prefetch 1 - block_sync_lds(); - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(I0)); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(I0), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(I0)); - }); - }); - }); - - // Global prefetch 2 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Local prefill 2 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); - - // Global prefetch 3 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - // This hot loop has two legacy loopover, to implement the double local buffer strategy - do - { - auto LoopFunc = [&](auto lds_read_buf, - auto lds_read_reg_buf, - auto lds_write_buf, - auto mfma_reg_buf) { - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - }); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); - - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf] - [Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - LoopFunc(I1, I1, I0, I0); - LoopFunc(I0, I0, I1, I1); - - i += HotloopUnroll; - } while(i < (num_loop - PrefetchStages)); - } - - auto ReadWriteCompFunc = [&](auto lds_read_buf, - auto lds_read_reg_buf, - auto lds_write_buf, - auto mfma_reg_buf) { - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - }); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { - block_sync_lds(); - - static_for<0, KRepeat, 1>{}([&](auto k) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf.At(lds_read_buf), - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_bufs(lds_read_reg_buf)); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf.At(lds_read_buf), - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_bufs(lds_read_reg_buf)); - }); - }); - }); - - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf][Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - - HotLoopScheduler(); - }; - - auto CompFunc = [&](auto mfma_reg_buf) { - a_thread_buf_tail = a_thread_bufs[mfma_reg_buf]; - b_thread_buf_tail = b_thread_bufs[mfma_reg_buf]; - }; - - // tail - if constexpr(TailNum == TailNumber::Odd) - { - ReadWriteCompFunc(I1, I1, I0, I0); - ReadCompFunc(I0, I0, I1); - CompFunc(I0); - } - else if constexpr(TailNum == TailNumber::Even) - { - ReadCompFunc(I1, I1, I0); - CompFunc(I1); - } - } - template - __device__ void Run(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - CThreadBuffer& c_thread_buf, - AThreadBuffer& a_thread_buf_tail, - BThreadBuffer& b_thread_buf_tail, - index_t num_loop) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_loop.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_loop.GetElementSpaceSize()); - - // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Local prefill 1 - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); - - // Global prefetch 2 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Global prefetch 3 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - // Initialize C - c_thread_buf.Clear(); - - // Local prefetch 1 - block_sync_lds(); - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_loop.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, I0), - a_block_buf, - a_thread_desc_loop, - make_tuple(m0, I0, I0, I0), - a_thread_buf); - }); - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_loop.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, I0), - b_block_buf, - b_thread_desc_loop, - make_tuple(n0, I0, I0, I0), - b_thread_buf); - }); - - // main body - if constexpr(HasMainLoop) - { - index_t i = 0; - do - { - auto LoopFunc = [&](auto vmem_buf) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KRepeat, 1>{}([&](auto k0) { - if constexpr(k0 == (KRepeat - 1)) - { - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf); - - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf); - - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - - block_sync_lds(); - } - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - }); - static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - - a_thread_copy_loop.Run( - a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), - a_block_buf, - a_thread_desc_loop, - make_tuple(m0, I0, I0, I0), - a_thread_buf); - }); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_loop.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), - b_block_buf, - b_thread_desc_loop, - make_tuple(n0, I0, I0, I0), - b_thread_buf); - }); - }); - - HotLoopScheduler(); - }; - - LoopFunc(I0); - LoopFunc(I1); - - i += HotloopUnroll; - } while(i < (num_loop - PrefetchStages)); - } - // tail - auto ReadWriteCompFunc = [&](auto vmem_buf) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KRepeat, 1>{}([&](auto k0) { - if constexpr(k0 == (KRepeat - 1)) - { - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf); - b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf); - - block_sync_lds(); - } - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - }); - static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - a_thread_copy_loop.Run( - a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), - a_block_buf, - a_thread_desc_loop, - make_tuple(m0, I0, I0, I0), - a_thread_buf); - }); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_loop.Run( - b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), - b_block_buf, - b_thread_desc_loop, - make_tuple(n0, I0, I0, I0), - b_thread_buf); - }); - }); - - HotLoopScheduler(); - }; - auto ReadCompFunc = [&]() { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k0, I0), - a_thread_buf_tail); - }); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k0, I0), - b_thread_buf_tail); - }); - }); - - HotLoopScheduler(); - }; - - if constexpr(TailNum == TailNumber::Odd) - { - ReadWriteCompFunc(I0); - ReadWriteCompFunc(I1); - ReadCompFunc(); - } - else if constexpr(TailNum == TailNumber::Even) - { - ReadWriteCompFunc(I0); - ReadCompFunc(); - } - } - template {}, I1, I1, Number{})); // B[NRepeat, N1, N2, KPack] - static constexpr auto b_thread_desc_loop = + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, @@ -912,19 +649,15 @@ struct BlockwiseGemmXdlops_pipeline_v5, Sequence<0, 1, 2, 3>, 3, B_K1, B_K1>; - AThreadCopy a_thread_copy_loop{Base::CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_loop{Base::CalculateBThreadOriginDataIndex()}; - using Base::a_thread_copy_; - using Base::a_thread_desc_; - using Base::b_thread_copy_; - using Base::b_thread_desc_; + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; using Base::c_thread_desc_; }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index f566e8c737..9acb505a49 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1229,38 +1229,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 return c_grid_desc_mblock_mperblock_nblock_nperblock; } - __device__ static constexpr auto EpilogueScheduler() - { - constexpr auto epilogue_tile = MPerBlock * NPerBlock * CShuffleMXdlPerWavePerShuffle * - CShuffleNXdlPerWavePerShuffle / (MXdlPerWave * NXdlPerWave); - constexpr auto num_mfma_inst = BlockwiseGemmPipe::HotLoopInstList::C_MFMA_Inst_Num * - CShuffleMXdlPerWavePerShuffle * - CShuffleNXdlPerWavePerShuffle / (MXdlPerWave * NXdlPerWave); - constexpr auto num_ds_write_inst = - epilogue_tile / BlockSize; // DefaultMFMA, per-element write - constexpr auto num_ds_read_inst = - epilogue_tile / BlockSize / CShuffleBlockTransferScalarPerVector_NPerBlock; - constexpr auto num_buffer_store_inst = num_ds_read_inst; - - // MFMA:ds_write=1:2 - constexpr auto num_ds_write_issue = num_ds_write_inst / 2; - constexpr auto num_mfma_block_sync = (num_mfma_inst - num_ds_write_issue) / 2; - constexpr auto mfma_ds_write_rate = MXdlPerWave == 16 ? 2 : 4; - - // Hide ds_write issue latency - static_for<0, num_ds_write_issue, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, mfma_ds_write_rate, 0); // DS write - }); - // Hide block_sync + ds_read latency - __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst, 0); // DS read - // Hide block_sync latency - __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x040, num_buffer_store_inst, 0); // VMEM write - } - // return block_id to C matrix tile idx (m0, n0) mapping // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; @@ -1409,14 +1377,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_; - constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_; - constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_; - - auto a_thread_buf = make_static_buffer( - a_thread_desc.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc.GetElementSpaceSize()); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / @@ -1435,21 +1395,14 @@ struct GridwiseGemm_xdl_cshuffle_v3 b_block_buf, b_block_slice_copy_step, c_thread_buf, - a_thread_buf, - b_thread_buf, num_k_block_main_loop); + // shuffle C and write out { - // Last block MFMA - auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm; - constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat; - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); - // Shuffle - // 1. Copy data from VGPR to LDS - // 2. Copy data from LDS to VGPR + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); @@ -1457,6 +1410,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -1478,12 +1433,19 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple(make_freeze_transform(I0), - make_unmerge_transform( - make_tuple(Number{}, M1, M2, M3, M4)), - make_freeze_transform(I0), - make_unmerge_transform( - make_tuple(Number{}, N1, N2))), + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); @@ -1550,31 +1512,31 @@ struct GridwiseGemm_xdl_cshuffle_v3 // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, - CElementwiseOperation, - CGlobalMemoryDataOperation, + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, - CShuffleDataType, - CDataType, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, - 3, - CShuffleBlockTransferScalarPerVector_NPerBlock, - true, - false>{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; - // SpaceFillingCurve tocombine all components - // C: VGPR to LDS + // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -1587,9 +1549,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 M4, 1>>{}; - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // Shuffled C: VGPR to Global + // space filling curve for shuffled blockwise C in global mem constexpr auto sfc_c_global = SpaceFillingCurve, Sequence<0, 2, 1, 3>, @@ -1598,91 +1558,22 @@ struct GridwiseGemm_xdl_cshuffle_v3 1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop; - - static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { - static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - }); - - __builtin_amdgcn_sched_barrier(0); static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS block_sync_lds(); - if constexpr(access_id < num_access - 1) - { - constexpr auto shuffle_m0 = - sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}]; - constexpr auto shuffle_n0 = - sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}]; - - static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { - static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc.CalculateOffset( - make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - }); - } - + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), c_thread_buf, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_shuffle_block_buf); + // make sure it's safe to read from LDS block_sync_lds(); // each block copy its data from LDS to global @@ -1696,10 +1587,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 { constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - - EpilogueScheduler(); } }); } @@ -1893,15 +1783,6 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_; - constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_; - constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_; - - auto a_thread_buf = make_static_buffer( - a_thread_desc.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc.GetElementSpaceSize()); - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); @@ -1919,20 +1800,14 @@ struct GridwiseGemm_xdl_cshuffle_v3 b_block_bufs, b_block_slice_copy_step, c_thread_buf, - a_thread_buf, - b_thread_buf, num_k_block_main_loop); - { - // Last block MFMA - auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm; - constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat; + // shuffle C and write out + { static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); - // Shuffle - // 1. Copy data from VGPR to LDS - // 2. Copy data from LDS to VGPR + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); @@ -1940,6 +1815,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); @@ -1961,12 +1838,19 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_tuple(make_freeze_transform(I0), - make_unmerge_transform( - make_tuple(Number{}, M1, M2, M3, M4)), - make_freeze_transform(I0), - make_unmerge_transform( - make_tuple(Number{}, N1, N2))), + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); @@ -2033,31 +1917,31 @@ struct GridwiseGemm_xdl_cshuffle_v3 // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, - CElementwiseOperation, - CGlobalMemoryDataOperation, + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, - CShuffleDataType, - CDataType, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, - 3, - CShuffleBlockTransferScalarPerVector_NPerBlock, - true, - false>{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_m_id, 0, block_n_id, 0), - c_element_op}; + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; - // SpaceFillingCurve tocombine all components - // C: VGPR to LDS + // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -2070,9 +1954,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 M4, 1>>{}; - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - // Shuffled C: VGPR to Global + // space filling curve for shuffled blockwise C in global mem constexpr auto sfc_c_global = SpaceFillingCurve, Sequence<0, 2, 1, 3>, @@ -2081,90 +1963,22 @@ struct GridwiseGemm_xdl_cshuffle_v3 1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop; - static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { - static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - }); - - __builtin_amdgcn_sched_barrier(0); static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS block_sync_lds(); - if constexpr(access_id < num_access - 1) - { - constexpr auto shuffle_m0 = - sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}]; - constexpr auto shuffle_n0 = - sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}]; - - static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { - static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc.CalculateOffset( - make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - }); - } - + // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), c_thread_buf, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_shuffle_block_buf); + // make sure it's safe to read from LDS block_sync_lds(); // each block copy its data from LDS to global @@ -2178,10 +1992,9 @@ struct GridwiseGemm_xdl_cshuffle_v3 { constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - - EpilogueScheduler(); } }); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index 4448f603ca..a9e73bf461 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -1220,38 +1220,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 return c_grid_desc_mblock_mperblock_nblock_nperblock; } - __device__ static constexpr auto EpilogueScheduler() - { - constexpr auto epilogue_tile = MPerBlock * NPerBlock * CShuffleMXdlPerWavePerShuffle * - CShuffleNXdlPerWavePerShuffle / (MXdlPerWave * NXdlPerWave); - constexpr auto num_mfma_inst = BlockwiseGemmPipe::HotLoopInstList::C_MFMA_Inst_Num * - CShuffleMXdlPerWavePerShuffle * - CShuffleNXdlPerWavePerShuffle / (MXdlPerWave * NXdlPerWave); - constexpr auto num_ds_write_inst = - epilogue_tile / BlockSize; // DefaultMFMA, per-element write - constexpr auto num_ds_read_inst = - epilogue_tile / BlockSize / CShuffleBlockTransferScalarPerVector_NPerBlock; - constexpr auto num_buffer_store_inst = num_ds_read_inst; - - // MFMA:ds_write=1:2 - constexpr auto num_ds_write_issue = num_ds_write_inst / 2; - constexpr auto num_mfma_block_sync = (num_mfma_inst - num_ds_write_issue) / 2; - constexpr auto mfma_ds_write_rate = MXdlPerWave == 16 ? 2 : 4; - - // Hide ds_write issue latency - static_for<0, num_ds_write_issue, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, mfma_ds_write_rate, 0); // DS write - }); - // Hide block_sync + ds_read latency - __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst, 0); // DS read - // Hide block_sync latency - __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x040, num_buffer_store_inst, 0); // VMEM write - } - // return block_id to C matrix tile idx (m0, n0) mapping // if arch = gfx942 using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; @@ -1429,15 +1397,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_; - constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_; - constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_; - - auto a_thread_buf = make_static_buffer( - a_thread_desc.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc.GetElementSpaceSize()); - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); @@ -1455,16 +1414,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 b_block_buf, b_block_slice_copy_step, c_thread_buf, - a_thread_buf, - b_thread_buf, num_k_block_main_loop); // shuffle C and write out { - // Last block MFMA - auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm; - constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat; - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); @@ -1624,9 +1577,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - // C: LDS -> VGPR - // D: Global -> VGPR - // E: =Epilogue(C, D), VGPR -> Global auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1685,84 +1635,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop; - - static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { - static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - }); - - __builtin_amdgcn_sched_barrier(0); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS block_sync_lds(); - if constexpr(access_id < num_access - 1) - { - constexpr auto shuffle_m0 = - sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}]; - constexpr auto shuffle_n0 = - sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}]; - - static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { - static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc.CalculateOffset( - make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - }); - } - // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), @@ -1796,8 +1672,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, cde_lds_and_global_step); - - // EpilogueScheduler(); } }); } @@ -1990,15 +1864,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_; - constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_; - constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_; - - auto a_thread_buf = make_static_buffer( - a_thread_desc.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc.GetElementSpaceSize()); - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); @@ -2016,16 +1881,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 b_block_bufs, b_block_slice_copy_step, c_thread_buf, - a_thread_buf, - b_thread_buf, num_k_block_main_loop); // shuffle C and write out { - // Last block MFMA - auto xdlops_gemm = blockwise_gemm_pipeline.xdlops_gemm; - constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat; - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); @@ -2243,84 +2102,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); - constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop; - - static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { - static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - }); - - __builtin_amdgcn_sched_barrier(0); static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS block_sync_lds(); - if constexpr(access_id < num_access - 1) - { - constexpr auto shuffle_m0 = - sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<0>{}]; - constexpr auto shuffle_n0 = - sfc_c_vgpr.GetIndexTupleOfNumber(access_id + Number<1>{})[Number<1>{}]; - - static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) { - static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) { - static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}]; - }); - - using mfma_input_type = - typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc.CalculateOffset( - make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0)); - - xdlops_gemm.Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - }); - } - // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), @@ -2354,8 +2139,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, cde_lds_and_global_step); - - // EpilogueScheduler(); } }); }