diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp index 9ed8311315..53271f0802 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -9,6 +9,8 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "binary_element_wise_operation.hpp" +#include "gridwise_binary_elementwise_1d.hpp" #include "tensor_operation/gpu/device/gemm_specialization.hpp" namespace ck { @@ -66,6 +68,41 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; + static constexpr auto ScalarPerVector = Number<4>{}; + + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t threadPerBlock) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t threadPerBlock) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number<2>{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<2>{}); + + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, threadPerBlock); + } + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) { const auto a_grid_desc_mraw_kraw = [&]() { @@ -333,6 +370,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< @@ -426,6 +464,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_); + + if constexpr(is_same::value) + { + c_grid_desc_m0_ = + DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize); + } + else if constexpr(is_same::value) + { + c_grid_desc_m0_ = + DeviceOp::MakeDescriptor_M0({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); + } } // private: @@ -440,6 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; + GridDesc_M0 c_grid_desc_m0_; typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; @@ -468,6 +520,35 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle float ave_time = 0; + using Add = ck::tensor_operation::binary_element_wise::Add; + using Substract = ck::tensor_operation::binary_element_wise::Substract; + using GridwiseBinAdd = GridwiseBinaryElementwise_1D; + using GridwiseBinSubstract = GridwiseBinaryElementwise_1D; + const auto add_kernel = kernel_elementwise_1d; + const auto substract_kernel = kernel_elementwise_1d; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { const auto kernel = kernel_gemm_xdl_cshuffle_v1< @@ -517,7 +598,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_); - // c_real = aux - aux_2 needed here!!! + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel(stream_config, + substract_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_real_, + arg.c_grid_desc_m0_, + arg.c_grid_desc_m0_, + arg.c_grid_desc_m0_, + Substract{}); ave_time += launch_and_time_kernel(stream_config, @@ -553,7 +646,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_); - // c_imag = aux + aux_2 needed here!!! + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel(stream_config, + add_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_imag_, + arg.c_grid_desc_m0_, + arg.c_grid_desc_m0_, + arg.c_grid_desc_m0_, + Add{}); } else { @@ -604,7 +709,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_); - // // c_real = aux - aux_2 needed here!!! + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel(stream_config, + substract_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_real_, + arg.c_grid_desc_m0_, + arg.c_grid_desc_m0_, + arg.c_grid_desc_m0_, + Substract{}); ave_time += launch_and_time_kernel(stream_config, @@ -640,7 +757,19 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_); - // c_imag = aux + aux_2 needed here!!! + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel(stream_config, + add_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_imag_, + arg.c_grid_desc_m0_, + arg.c_grid_desc_m0_, + arg.c_grid_desc_m0_, + Add{}); } return ave_time; diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index ebec6b5b50..d6c113213a 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -12,6 +12,39 @@ struct Add { dst = src1 + src2; } + + __host__ __device__ constexpr void + operator()(half_t& dst, const half_t& src1, const half_t& src2) const + { + dst = src1 + src2; + } + + __host__ __device__ constexpr void + operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const + { + dst = src1 + src2; + } +}; + +struct Substract +{ + __host__ __device__ constexpr void + operator()(float& dst, const float& src1, const float& src2) const + { + dst = src1 - src2; + } + + __host__ __device__ constexpr void + operator()(half_t& dst, const half_t& src1, const half_t& src2) const + { + dst = src1 - src2; + } + + __host__ __device__ constexpr void + operator()(bhalf_t& dst, const bhalf_t& src1, const bhalf_t& src2) const + { + dst = src1 - src2; + } }; } // namespace binary_element_wise