diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 479249b990..4a528831ad 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -35,6 +35,225 @@ namespace ck { namespace tensor_operation { namespace device { +// Helper function to dispatch split-K hack for standard kernel (single LDS) +template +__device__ void DispatchSplitKHack(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) +{ + if(split_k_offset_a_hack && split_k_offset_b_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_a_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_b_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + +// Helper function to dispatch split-K hack for 2lds kernel +template +__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) +{ + if(split_k_offset_a_hack && split_k_offset_b_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_a_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_b_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + template (karg.p_a_grid + a_batch_offset + split_k_offset_a, - karg.p_b_grid + b_batch_offset + split_k_offset_b, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx * num_k_per_block, - gridDim.y, - split_k_offset_a_hack, - split_k_offset_b_hack); + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = karg; @@ -156,24 +376,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + split_k_offset_a, - karg.p_b_grid + b_batch_offset + split_k_offset_b, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx * num_k_per_block, - gridDim.y, - split_k_offset_a_hack, - split_k_offset_b_hack); + DispatchSplitKHack_2Lds(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = karg; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index fb82df5e31..15fade412e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -34,6 +34,225 @@ namespace ck { namespace tensor_operation { namespace device { +// Helper function to dispatch split-K hack for standard kernel (single LDS) +template +__device__ void DispatchSplitKHack(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) +{ + if(split_k_offset_a_hack && split_k_offset_b_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_a_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_b_hack) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + +// Helper function to dispatch split-K hack for 2lds kernel +template +__device__ void DispatchSplitKHack_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const typename GridwiseGemm::Argument& karg, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id, + index_t k_batch, + bool split_k_offset_a_hack, + bool split_k_offset_b_hack) +{ + if(split_k_offset_a_hack && split_k_offset_b_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_a_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else if(split_k_offset_b_hack) + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } + else + { + GridwiseGemm::template Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id, + k_batch); + } +} + template (karg.p_a_grid + a_batch_offset + split_k_offset_a, - karg.p_b_grid + b_batch_offset + split_k_offset_b, - karg.p_c_grid + e_batch_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx * num_k_per_block, - gridDim.y, - split_k_offset_a_hack, - split_k_offset_b_hack); + + DispatchSplitKHack(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = karg; @@ -158,24 +379,25 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + split_k_offset_a, - karg.p_b_grid + b_batch_offset + split_k_offset_b, - karg.p_c_grid + e_batch_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx * num_k_per_block, - gridDim.y, - split_k_offset_a_hack, - split_k_offset_b_hack); + DispatchSplitKHack_2Lds(karg.p_a_grid + a_batch_offset + split_k_offset_a, + karg.p_b_grid + b_batch_offset + split_k_offset_b, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx * num_k_per_block, + gridDim.y, + split_k_offset_a_hack, + split_k_offset_b_hack); } #else ignore = karg; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 6299ba22d7..f42d4adb1a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -663,7 +663,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, - TailNumber TailNum = TailNumber::Odd> + TailNumber TailNum = TailNumber::Odd, + bool SplitKOffsetAHack = false, + bool SplitKOffsetBHack = false> __device__ static void Run(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, @@ -673,13 +675,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0, - const index_t k_batch = 1, - const bool split_k_offset_a_hack = false, - const bool split_k_offset_b_hack = false) + const index_t k_id = 0, + const index_t k_batch = 1) { - const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1; - const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1; + const long_index_t a_space_size_divisor = SplitKOffsetAHack ? k_batch : 1; + const long_index_t b_space_size_divisor = SplitKOffsetBHack ? k_batch : 1; const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); @@ -750,7 +750,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetAHack ? 0 : k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -781,7 +781,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetBHack ? 0 : k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -1030,7 +1030,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, - TailNumber TailNum = TailNumber::Odd> + TailNumber TailNum = TailNumber::Odd, + bool SplitKOffsetAHack = false, + bool SplitKOffsetBHack = false> __device__ static void Run_2Lds(const ADataType* p_a_grid, const BDataType* p_b_grid, CDataType* p_c_grid, @@ -1041,13 +1043,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0, - const index_t k_batch = 1, - const bool split_k_offset_a_hack = false, - const bool split_k_offset_b_hack = false) + const index_t k_id = 0, + const index_t k_batch = 1) { - const long_index_t a_space_size_divisor = split_k_offset_a_hack ? k_batch : 1; - const long_index_t b_space_size_divisor = split_k_offset_b_hack ? k_batch : 1; + const long_index_t a_space_size_divisor = SplitKOffsetAHack ? k_batch : 1; + const long_index_t b_space_size_divisor = SplitKOffsetBHack ? k_batch : 1; const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize() / a_space_size_divisor); @@ -1118,7 +1118,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( a_grid_desc_ak0_m_ak1, - make_multi_index(split_k_offset_a_hack ? 0 : k_id, m_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetAHack ? 0 : k_id, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1149,7 +1149,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 true, BlockwiseGemmPipe::GlobalBufferNum>( b_grid_desc_bk0_n_bk1, - make_multi_index(split_k_offset_b_hack ? 0 : k_id, n_block_data_idx_on_grid, 0), + make_multi_index(SplitKOffsetBHack ? 0 : k_id, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0),