From 782ba2c3b930d63260ecbe53d4bd78582fb06b66 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 17 Oct 2023 09:24:02 -0500 Subject: [PATCH] added ab_elementwise_op support into splitK Gemm (#956) * add ab_elementwise * fixed ci * fixed a merge issue * fixed pr comments * fixed a conflict * remove 61_example --------- Co-authored-by: Jing Zhang [ROCm/composable_kernel commit: bf0addb5753fb44e39e33e0edfa2158d6f1ffce7] --- .../impl/device_gemm_xdl_splitk_c_shuffle.hpp | 102 +++++++++++++++--- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 27 +++-- 2 files changed, 104 insertions(+), 25 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index d20b008b88..b0193857c8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; - using Argument = typename GridwiseGemm::Argument; + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; // Invoker @@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); }; if(has_main_k0_block_loop) @@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(static_cast(p_a), @@ -312,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + typename Block2CTileMap, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, - const Block2CTileMap& b2c_map) + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) @@ -37,10 +43,13 @@ __global__ void __shared__ uint8_t p_shared[shared_size]; GridwiseGemm::template Run( - karg, static_cast(p_shared), b2c_map); + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); #else ignore = karg; ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -577,7 +586,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 typename Block2CTileMap> __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block, - const Block2CTileMap& block_2_ctile_map) + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) { const FloatA* p_a_grid = karg.p_a_grid; const FloatB* p_b_grid = karg.p_b_grid; @@ -590,9 +602,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); - const AElementwiseOperation a_element_op = AElementwiseOperation{}; - const BElementwiseOperation b_element_op = BElementwiseOperation{}; - const CElementwiseOperation c_element_op = CElementwiseOperation{}; const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); @@ -761,8 +770,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, - ComputeType, + ComputeType, // ComputeType A + ComputeType, // ComputeType B FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc),