From 2d4c5129cedda3d1b3184ca117f676c2cc06d954 Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Fri, 13 Jun 2025 09:55:51 +0000 Subject: [PATCH] Optimize clamp --- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 11 ++- ..._conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 6 +- ...d_multiple_d_xdl_large_tensor_cshuffle.hpp | 5 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 95 +++++++++++------- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 96 ++++++++++++------- 5 files changed, 140 insertions(+), 73 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 27da1d91a3..6d04835b21 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -311,8 +311,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static_assert(NumGroupsToMerge >= 1); - static constexpr bool isMultiA = is_detected::value; - static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiAB = isMultiA || isMultiB; // NGCHW is not supported for multiAB static_assert(!(is_NGCHW_NGKHW() || @@ -323,6 +324,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle static constexpr index_t NumBTensor = GetNumABTensors(); static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr bool DoElementwiseBeforeCShuffle = + NumDTensor == 0 && !isMultiAB && is_same_v && + !is_same_v; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -465,7 +470,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ - BComputeDataType + BComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm using GridwiseGemm = std::conditional_t< isMultiA || isMultiB, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 714025a421..48424c16b9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -279,6 +279,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr bool isMultiD = DsDataType::Size() > 0; static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + static constexpr bool DoElementwiseBeforeCShuffle = + !isMultiABD && is_same_v && + !is_same_v; + static constexpr index_t NumATensor = GetNumABTensors(); static constexpr index_t NumBTensor = GetNumABTensors(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -412,7 +416,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ - AComputeDataType, BComputeDataType + AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 94a4e0da4c..9988367959 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -192,6 +192,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t MaxGemmsNum = 32; + static constexpr bool DoElementwiseBeforeCShuffle = + NumDTensor == 0 && is_same_v && + !is_same_v; static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -361,7 +364,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ - AComputeDataType + AComputeDataType, DoElementwiseBeforeCShuffle // Use appropriate gridwise gemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index be0fff087e..acbccf1889 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -71,11 +71,13 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType_ = AComputeDataType_, + bool DoElementwiseBeforeCShuffle = false> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); + static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0); using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; @@ -796,37 +798,60 @@ struct GridwiseGemmMultipleD_xdl_cshuffle n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( @@ -860,7 +885,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CDEElementwiseOperation, + conditional_t, Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence // support arbitray type Sequence<1, @@ -881,7 +908,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), - cde_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR before shuffle constexpr auto sfc_c_vgpr = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 57742f783c..6270d0c4dc 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -186,6 +186,8 @@ __global__ void /// in global memory. Currently not supported! /// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout /// in global memory (pre-shuffled). +/// @tparam DoElementwiseBeforeCShuffle Whether the cde_elementwise should be performed before or +/// after elementwise op. template + bool PermuteB = false, + bool DoElementwiseBeforeCShuffle = false> struct GridwiseGemm_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; @@ -1615,42 +1618,67 @@ struct GridwiseGemm_xdl_cshuffle_v3 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return problem.c_element_op_; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return problem.c_element_op_; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, + ThisThreadBlock, // ThreadGroup + conditional_t, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, @@ -1671,7 +1699,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_multi_index(0, 0, 0, 0), c_grid_desc_mblock_mperblock_nblock_nperblock, make_multi_index(block_m_id, 0, block_n_id, 0), - problem.c_element_op_}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr =