From 18707866d906ab79980ea8f2695e56bcb8cd4d77 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 10 Apr 2022 03:01:58 +0000 Subject: [PATCH] adding thread group --- .../device/device_gemm_xdl_cshuffle_v2.hpp | 77 ++++++------ .../gpu/grid/gridwise_gemm_pipeline_v2.hpp | 92 ++++++++++---- .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 118 ++++++++++++++---- 3 files changed, 206 insertions(+), 81 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp index ba943763c6..58ad77f180 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp @@ -27,7 +27,8 @@ template -struct GridwiseGemmPipeline_v2{}; static constexpr auto I1 = Number<1>{}; - static __device__ void RunProducer(const AGridDesc& a_grid_desc, - const ABlockDesc& a_block_desc, - ABlockTransfer& a_blockwise_copy, - const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, - const ABlockTransferStep& a_block_copy_step, - const BGridDesc& b_grid_desc, - const BBlockDesc& b_block_desc, - BBlockTransfer& b_blockwise_copy, - const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, - const BBlockTransferStep& b_block_copy_step, - index_t num_loop) + __device__ constexpr GridwiseGemmPipeline_v2() + { + // TODO static assert + } + + static __device__ void RunABBlockTransferPipeline(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + index_t num_loop) { // global read 0 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); @@ -140,11 +151,11 @@ struct GridwiseGemmPipeline_v2{}; static constexpr auto BK1 = Number{}; + using ThisThreadBlock = + AnyThreadBlock; + +#if 1 + using ABBlockTransferThreadGroup = ThisThreadBlock; + using BlockGemmThreadGroup = ThisThreadBlock; + using CShuffleBlockTransferThreadGroup = ThisThreadBlock; +#else + struct ABBlockTransferThreadGroup + { + __device__ static constexpr index_t GetNumOfThread() + { + return ABBlockTransferThreadGroupSize; + } + + __device__ static constexpr bool IsBelong() + { + return get_thread_local_1d_id() < ABBlockTransferThreadGroupSize; + } + + __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); } + }; + + struct BlockGemmThreadGroup + { + __device__ static constexpr index_t GetNumOfThread() + { + return ABBlockTransferThreadGroupSize; + } + + __device__ static constexpr bool IsBelong() + { + return get_thread_local_1d_id() >= ABBlockTransferThreadGroupSize; + } + + __device__ static index_t GetThreadId() + { + return get_thread_local_1d_id() - ABBlockTransferThreadGroupSize; + } + }; + + using CShuffleBlockTransferThreadGroup = ThisThreadBlock; +#endif + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy @@ -345,11 +390,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - using ThisThreadBlock = AnyThreadBlock; - // A matrix blockwise copy auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1::selected_mfma.k_per_blk); auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, @@ -465,10 +513,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 remove_cvref_t, NumGemmKPrefetchStage, HasMainK0BlockLoop>{}; - - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); +#else + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v2, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NumGemmKPrefetchStage, + HasMainK0BlockLoop>{}; +#endif gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -601,7 +667,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // index_t BlockSize, + ThisThreadBlock, // ThreadGroup CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, @@ -655,22 +721,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v2 // make sure it's safe to write to LDS block_sync_lds(); - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_shuffle_block_buf); + if(BlockGemmThreadGroup::IsBelong()) + { + // thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + } // make sure it's safe to read from LDS block_sync_lds(); - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); + if(CShuffleBlockTransferThreadGroup::IsBelong()) + { + // block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } if constexpr(access_id < num_access - 1) {