diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp index ed34468c58..49fa6676cd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -10,7 +10,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" #include "ck/host_utility/device_prop.hpp" @@ -21,13 +21,14 @@ namespace ck { namespace tensor_operation { namespace device { -/// @brief \"Universal\" GEMM operation with SplitK support. +/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors. /// /// @par Overview /// This GEMM operation implements the following mathematical equation: -/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) -/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are -/// elementwise operations applied to the A, B, and C tensors, respectively. +/// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...) +/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise +/// operations that could be applied on each tensor respectively. The CDE_op is an +/// elementwise operation applied to the C and all D tensors. /// The \"universal\" gemm comes with multiple pipelines optimized for different usage /// scenarios. That's why it's called \"universal\". It's universal through it's design /// and versatilty. 
@@ -39,18 +40,20 @@ namespace device { /// /// @tparam ALayout A tensor data layout. /// @tparam BLayout B tensor data layout. -/// @tparam CLayout C tensor data layout. +/// @tparam DsLayout D tensors data layouts. +/// @tparam ELayout E tensor data layout. /// @tparam ADataType A tensor data type. /// @tparam BDataType B tensor data type. -/// @tparam CDataType C tensor data type. +/// @tparam DsDataType D tensors data types. +/// @tparam EDataType E tensor data type. /// @tparam AccDataType The accumulation data type related to the hardware /// matrix-multiplication instruction. /// @tparam CShuffleDataType The data type used to store matrix-multiplication results into /// LDS memory during \"CShuffle\" data layout optimization. /// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. /// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. -/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor -/// (after GEMM). +/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after +/// GEMM) and D input tensors. /// @tparam GemmSpec Determines used "padding" version. /// @tparam BlockSize The number of threads within workgroup. /// @tparam MPerBlock The input/output data tile size in the M dimension. @@ -104,11 +107,12 @@ namespace device { /// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions /// results to process per wave per iteration of CShuffle /// in N dimension. -/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial /// thread distribution used for storing data into output /// tensor across output data layout dimensions. -/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. 
-/// Used when storing data to output tensor. +/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access. +/// Used when loading data from D tensors and storing data +/// to output tensor. /// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or /// intrawave). /// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. @@ -122,15 +126,17 @@ namespace device { /// in global memory (pre-shuffled). template -struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 +struct DeviceGemmMultipleD_Wmma_CShuffleV3 + : public DeviceGemmMultipleDSplitK { - // GridwiseGemm + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, - Tuple<>, // DsLayout - CLayout, + DsLayout, + ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, - Tuple<>, // DsDataType - CDataType, + DsDataType, + EDataType, AElementwiseOperation, BElementwiseOperation, - CElementwiseOperation, + CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, @@ -220,8 +230,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -285,8 +295,21 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + std::array size_ds_buffers; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + size_ds_buffers[i] = + ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, + stream_config.rotating_count, + size_a_buffer, + size_b_buffer, + size_ds_buffers); 
rotating_mem.Print(); auto run_flush_cache = [&]() { @@ -298,7 +321,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 1) HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, 0, - arg_.M * arg_.N * sizeof(CDataType), + arg_.M * arg_.N * sizeof(EDataType), stream_config.stream_id_)); }; @@ -316,7 +339,7 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 1) HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid, 0, - arg.M * arg.N * sizeof(CDataType), + arg.M * arg.N * sizeof(EDataType), stream_config.stream_id_)); ave_time = launch_and_time_kernel( @@ -419,8 +442,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 || - std::is_same_v) + if constexpr(std::is_same_v || + std::is_same_v) { if(arg.KBatch > 1 && ck::is_gfx11_supported()) { @@ -455,36 +478,33 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2(p_arg)); } - index_t GetKPerBlock() override { return KPerBlock; } - - bool GetPermuteA() override { return PermuteA; } - bool GetPermuteB() override { return PermuteB; } - - static auto MakeArgument(const ADataType* p_a, - const BDataType* p_b, - CDataType* p_c, + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, - index_t StrideC, + std::array StrideDs, + index_t StrideE, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation cde_element_op) + CDEElementwiseOperation cde_element_op) { - return Argument{p_a, - p_b, - std::array{}, // p_ds_grid_ - p_c, + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), M, N, K, StrideA, StrideB, - std::array{}, // StrideDs_ - StrideC, + StrideDs, + StrideE, KBatch, a_element_op, b_element_op, @@ -494,35 +514,38 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - 
index_t StrideC, - index_t KBatch, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override { return std::make_unique(static_cast(p_a), static_cast(p_b), - std::array{}, // p_ds_grid_ - static_cast(p_c), + p_ds, + static_cast(p_e), M, N, K, StrideA, StrideB, - std::array{}, // StrideDs_ - StrideC, + StrideDs, + StrideE, KBatch, a_element_op, b_element_op, - c_element_op); + cde_element_op); } // polymorphic @@ -548,12 +571,17 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2{}([&](auto i) { + using DLayout = remove_cvref_t>; + + str << std::string(DLayout::name)[0]; + }); + str << std::string(ELayout::name)[0] << ">" << " BlkSize: " << BlockSize << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp index ed34468c58..40628c487e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp @@ -176,7 +176,6 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 { - // GridwiseGemm using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 3a24385ef4..9dce3f22a3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ 
-64,8 +64,9 @@ __global__ void /// @par Overview /// This GEMM kernel is carrying out following mathematical equation: /// E{M,N} = CDE_op(A_op(A{M,K}) * B_op(B{K,N}), Ds{M,N}...) -/// Where A, B, Ds are input tensors and E is the output tensor. The A/B/CDE_op are -/// elementwise operations that could be applied on each tensor respectively. +/// Where A, B, Ds are input tensors and E is the output tensor. The A/B are elementwise +/// operations that could be applied on each tensor respectively. The CDE_op is an +/// elementwise operation applied to the C and all D tensors. /// The \"universal\" gemm comes with multiple pipelines optimized for different usage /// scenarios. That's why it's called \"universal\". It's universal through it's design /// and versatilty. @@ -77,7 +78,7 @@ __global__ void /// /// @tparam ALayout A tensor data layout. /// @tparam BLayout B tensor data layout. -/// @tparam DsLayout D tensors data layout. +/// @tparam DsLayout D tensors data layouts. /// @tparam ELayout E tensor data layout. /// @tparam ADataType A tensor data type. /// @tparam BDataType B tensor data type. @@ -85,6 +86,7 @@ __global__ void /// matrix-multiplication instruction. /// @tparam CShuffleDataType The data type used to store matrix-multiplication results into /// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam DsDataType D tensors data types. /// @tparam EDataType E tensor data type. /// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. /// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. 
@@ -542,7 +544,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 } template - __device__ static auto + __host__ __device__ static auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE) { const auto c_grid_desc_mraw_nraw = [&]() { @@ -620,7 +622,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 using DsGridPointer = decltype(MakeDsGridPointer()); - __device__ static auto MakeDsGridDescriptor_M_N( + __host__ __device__ static auto MakeDsGridDescriptor_M_N( index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) { return generate_tuple( @@ -747,9 +749,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 is_reduce(is_reduce_) { static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType_ = remove_cvref_t>; + using DDataType = remove_cvref_t>; - p_ds_grid(i) = static_cast(p_ds_grid_[i]); + p_ds_grid(i) = static_cast(p_ds_grid_[i]); }); }